Example #1
    def _test_pipeline(self, mode, params=None):
        """Helper function to test the full model pipeline.
    """
        # Create source and target example
        source_len = 10
        target_len = self.max_decode_length + 10
        source = " ".join(np.random.choice(self.vocab_list, source_len))
        target = " ".join(np.random.choice(self.vocab_list, target_len))
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=[source], targets=[target])

        # Build model graph
        model = self.create_model(params)
        data_provider = lambda: data_utils.make_parallel_data_provider(
            [sources_file.name], [targets_file.name])
        input_fn = training_utils.create_input_fn(data_provider,
                                                  self.batch_size)
        features, labels = input_fn()
        fetches = model(features, labels, None, mode)
        fetches = [_ for _ in fetches if _ is not None]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                fetches_ = sess.run(fetches)

        sources_file.close()
        targets_file.close()

        return model, fetches_
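A concrete test case would invoke this helper once per mode. A minimal sketch, assuming the enclosing tf.test.TestCase subclass provides `create_model`, `batch_size`, and `max_decode_length` as the helper expects (the method names are hypothetical):

    def test_train(self):
        # In TRAIN mode the model returns (predictions, loss, train_op);
        # none are None, so all three survive the helper's filtering.
        model, fetches_ = self._test_pipeline(tf.contrib.learn.ModeKeys.TRAIN)
        self.assertEqual(len(fetches_), 3)

    def test_infer(self):
        # In INFER mode loss and train_op are None, leaving only predictions.
        model, fetches_ = self._test_pipeline(tf.contrib.learn.ModeKeys.INFER)
        self.assertEqual(len(fetches_), 1)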
Example #2
def create_inference_graph(model, input_pipeline, batch_size=32):
    """Creates a graph to perform inference.

  Args:
    model: The model instance used to build the inference graph.
    input_pipeline: An instance of `InputPipeline` that defines
      how to read and parse data.
    batch_size: The batch size used for inference.

  Returns:
    The return value of the model function, typically a tuple of
    (predictions, loss, train_op).
  """

    # TODO: This doesn't really belong here.
    # How to get rid of this?
    if hasattr(model, "use_beam_search"):
        if model.use_beam_search:
            tf.logging.info("Setting batch size to 1 for beam search.")
            batch_size = 1

    input_fn = training_utils.create_input_fn(pipeline=input_pipeline,
                                              batch_size=batch_size,
                                              allow_smaller_final_batch=True)

    # Build the graph
    features, labels = input_fn()
    return model(features=features, labels=labels, params=None)
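A usage sketch for this variant; the source file path, the `model` instance, and the pipeline configuration are illustrative assumptions:

# Build an inference graph from a pre-constructed model. Target files are
# not needed at inference time, so the list is left empty.
pipeline = input_pipeline.ParallelTextInputPipeline(
    params={"source_files": ["/tmp/sources.txt"], "target_files": []},
    mode=tf.contrib.learn.ModeKeys.INFER)
predictions, _, _ = create_inference_graph(
    model=model, input_pipeline=pipeline, batch_size=32)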
Example #4
def create_inference_graph(
    model_dir,
    input_file,
    batch_size=32,
    beam_width=None):
  """Creates a graph to perform inference.

  Args:
    model_dir: The output directory passed during training. This
      directory must contain model checkpoints.
    input_file: A source input file to read from.
    batch_size: The batch size used for inference.
    beam_width: The beam width for beam search. If None,
      no beam search is used.

  Returns:
    The return value of the model function, typically a tuple of
    (predictions, loss, train_op).
  """

  params_overrides = {}
  if beam_width is not None:
    tf.logging.info("Setting batch size to 1 for beam search.")
    batch_size = 1
    params_overrides["inference.beam_search.beam_width"] = beam_width

  model = load_model(model_dir)

  data_provider = lambda: data_utils.make_parallel_data_provider(
      data_sources_source=[input_file],
      data_sources_target=None,
      shuffle=False,
      num_epochs=1)

  input_fn = training_utils.create_input_fn(
      data_provider_fn=data_provider,
      batch_size=batch_size,
      allow_smaller_final_batch=True)

  # Build the graph
  features, labels = input_fn()
  return model(
      features=features,
      labels=labels,
      params=None,
      mode=tf.contrib.learn.ModeKeys.INFER)
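A usage sketch for this checkpoint-based variant, with placeholder paths: restore the latest checkpoint (instead of initializing global variables) and drain the single-epoch input queue.

predictions, _, _ = create_inference_graph(
    model_dir="/tmp/model", input_file="/tmp/sources.txt", beam_width=5)
saver = tf.train.Saver()
with tf.Session() as sess:
  # Trained weights come from the checkpoint, not an initializer.
  saver.restore(sess, tf.train.latest_checkpoint("/tmp/model"))
  sess.run(tf.local_variables_initializer())
  sess.run(tf.tables_initializer())
  with tf.contrib.slim.queues.QueueRunners(sess):
    predictions_ = sess.run(predictions)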
Example #5
    def _test_with_args(self, **kwargs):
        """Helper function to test create_input_fn with keyword arguments"""
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=["Hello World ."], targets=["Goodbye ."])
        data_provider_fn = lambda: data_utils.make_parallel_data_provider(
            [sources_file.name], [targets_file.name])
        input_fn = training_utils.create_input_fn(
            data_provider_fn=data_provider_fn, **kwargs)
        features, labels = input_fn()

        with self.test_session() as sess:
            with tf.contrib.slim.queues.QueueRunners(sess):
                features_, labels_ = sess.run([features, labels])

        self.assertEqual(set(features_.keys()),
                         set(["source_tokens", "source_len"]))
        self.assertEqual(set(labels_.keys()),
                         set(["target_tokens", "target_len"]))
Example #6
def test_copy_gen_model(record_path, vocab_path=None):

    tf.logging.set_verbosity(tf.logging.INFO)

    vocab = Vocab(vocab_path)
    batch_size = 2

    # Build model graph
    mode = tf.contrib.learn.ModeKeys.TRAIN
    params_ = CopyGenSeq2Seq.default_params().copy()
    params_.update(TEST_PARAMS)
    params_.update({
        "vocab_source": vocab_path,
        "vocab_target": vocab_path,
    })
    print(params_)

    model = CopyGenSeq2Seq(params=params_, mode=mode, vocab_instance=vocab)

    tf.logging.info(vocab_path)

    input_pipeline_ = input_pipeline.FeaturedTFRecordInputPipeline(
        params={
            "files": [record_path],
            "shuffle": True,
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                              batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]
    # Only needed if the commented-out tf_debug wrapper below is enabled.
    from tensorflow.python import debug as tf_debug

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            fetches_ = sess.run(fetches)
            print("yes")

    return model, fetches_
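An illustrative invocation; the TFRecord and vocabulary paths are placeholders:

model, fetches_ = test_copy_gen_model(
    record_path="/tmp/train.tfrecords", vocab_path="/tmp/vocab.txt")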
Example #7
    def test_pipeline(self):
        # Create source and target example
        source_len = 10
        target_len = self.max_decode_length + 10
        source = " ".join(np.random.choice(self.vocab_list, source_len))
        target = " ".join(np.random.choice(self.vocab_list, target_len))
        tfrecords_file = test_utils.create_temp_tfrecords(source=source,
                                                          target=target)

        # Build model graph
        model = self.create_model()
        featurizer = model.create_featurizer()
        data_provider = lambda: inputs.make_data_provider(
            [tfrecords_file.name])
        input_fn = training_utils.create_input_fn(data_provider, featurizer,
                                                  self.batch_size)
        features, labels = input_fn()
        predictions, loss, train_op = model(features, labels, None,
                                            tf.contrib.learn.ModeKeys.TRAIN)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                predictions_, loss_, _ = sess.run(
                    [predictions, loss, train_op])

        # We have predictions for each target word and the SEQUENCE_START token.
        # That's why it's `target_len + 1`.
        max_decode_length = model.params["target.max_seq_len"]
        expected_decode_len = np.minimum(target_len + 1, max_decode_length)

        np.testing.assert_array_equal(predictions_["logits"].shape, [
            self.batch_size, expected_decode_len,
            model.target_vocab_info.total_size
        ])
        np.testing.assert_array_equal(predictions_["predictions"].shape,
                                      [self.batch_size, expected_decode_len])
        self.assertFalse(np.isnan(loss_))

        tfrecords_file.close()
Example #8
  def _test_with_args(self, **kwargs):
    """Helper function to test create_input_fn with keyword arguments"""
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=["Hello World ."], targets=["Goodbye ."])

    pipeline = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [sources_file.name],
            "target_files": [targets_file.name]
        },
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    input_fn = training_utils.create_input_fn(pipeline=pipeline, **kwargs)
    features, labels = input_fn()

    with self.test_session() as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        features_, labels_ = sess.run([features, labels])

    self.assertEqual(
        set(features_.keys()), set(["source_tokens", "source_len"]))
    self.assertEqual(set(labels_.keys()), set(["target_tokens", "target_len"]))
Example #9
def test_model(source_path, target_path, vocab_path):

    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2

    # Build model graph
    mode = tf.contrib.learn.ModeKeys.TRAIN
    params_ = AttentionSeq2Seq.default_params().copy()
    params_.update({
        "vocab_source": vocab_path,
        "vocab_target": vocab_path,
    })
    model = AttentionSeq2Seq(params=params_, mode=mode)

    tf.logging.info(vocab_path)

    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [source_path],
            "target_files": [target_path]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                              batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)

    fetches = [_ for _ in fetches if _ is not None]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            fetches_ = sess.run(fetches)

    return model, fetches_
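An illustrative invocation built on the temp-file helpers used in the other examples on this page; the sentence and vocabulary contents are placeholders:

sources_file, targets_file = test_utils.create_temp_parallel_data(
    sources=["Hello World ."], targets=["Goodbye ."])
vocab_file = test_utils.create_temporary_vocab_file(
    ["Hello", "World", "Goodbye", "."])
model, fetches_ = test_model(
    sources_file.name, targets_file.name, vocab_file.name)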
Example #10
  def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.
    """
    # Create source and target example
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    # Build model graph
    model = self.create_model(mode, params)
    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [sources_file.name],
            "target_files": [targets_file.name]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(
        pipeline=input_pipeline_, batch_size=self.batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        fetches_ = sess.run(fetches)

    sources_file.close()
    targets_file.close()

    return model, fetches_
Example #11
def create_experiment(output_dir):
    """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

    config = run_config.RunConfig(
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
    config.tf_config.log_device_placement = FLAGS.log_device_placement

    train_options = training_utils.TrainOptions(
        model_class=FLAGS.model, model_params=FLAGS.model_params)
    # On the main worker, save training options
    if config.is_chief:
        gfile.MakeDirs(output_dir)
        train_options.dump(output_dir)

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Training data input pipeline
    train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=FLAGS.input_pipeline_train,
        mode=tf.contrib.learn.ModeKeys.TRAIN)

    # Create training input function
    train_input_fn = training_utils.create_input_fn(
        pipeline=train_input_pipeline,
        batch_size=FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries,
        mode=tf.contrib.learn.ModeKeys.TRAIN)

    # Development data input pipeline
    dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=FLAGS.input_pipeline_dev,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        shuffle=False,
        num_epochs=1)

    # Create eval input function
    eval_input_fn = training_utils.create_input_fn(
        pipeline=dev_input_pipeline,
        batch_size=FLAGS.batch_size,
        allow_smaller_final_batch=True,
        mode=tf.contrib.learn.ModeKeys.EVAL)

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        model = _create_from_dict(
            {
                "class": train_options.model_class,
                "params": train_options.model_params
            },
            models,
            mode=mode)
        return model(features, labels, params)

    estimator = tf.contrib.learn.Estimator(model_fn=model_fn,
                                           model_dir=output_dir,
                                           config=config,
                                           params=FLAGS.model_params)

    # Create hooks
    train_hooks = []
    for dict_ in FLAGS.hooks:
        hook = _create_from_dict(dict_,
                                 hooks,
                                 model_dir=estimator.model_dir,
                                 run_config=config)
        train_hooks.append(hook)

    # Create metrics
    eval_metrics = {}
    for dict_ in FLAGS.metrics:
        metric = _create_from_dict(dict_, metric_specs)
        eval_metrics[metric.name] = metric

    experiment = PatchedExperiment(estimator=estimator,
                                   train_input_fn=train_input_fn,
                                   eval_input_fn=eval_input_fn,
                                   min_eval_frequency=FLAGS.eval_every_n_steps,
                                   train_steps=FLAGS.train_steps,
                                   eval_steps=None,
                                   eval_metrics=eval_metrics,
                                   train_monitors=train_hooks)

    return experiment
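A minimal entry point for this factory, following the standard tf.contrib.learn pattern; the `output_dir` and `schedule` flags are assumptions, not defined in this example:

from tensorflow.contrib.learn.python.learn import learn_runner

def main(_argv):
    # learn_runner calls create_experiment(output_dir) and then runs the
    # requested schedule, e.g. "train_and_evaluate".
    learn_runner.run(
        experiment_fn=create_experiment,
        output_dir=FLAGS.output_dir,
        schedule=FLAGS.schedule)

if __name__ == "__main__":
    tf.app.run()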
Example #12
def create_experiment(output_dir):
    """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

    # Load vocabulary info
    source_vocab_info = inputs.get_vocab_info(FLAGS.vocab_source)
    target_vocab_info = inputs.get_vocab_info(FLAGS.vocab_target)

    # Create data providers
    train_data_provider = lambda: inputs.make_data_provider([FLAGS.data_train])
    dev_data_provider = lambda: inputs.make_data_provider([FLAGS.data_dev])

    # Find model class
    model_class = getattr(models, FLAGS.model)

    # Parse parameters and merge them with the defaults
    hparams = model_class.default_params()
    if FLAGS.hparams is not None:
        hparams = HParamsParser(hparams).parse(FLAGS.hparams)

    # Print hyperparameter values
    tf.logging.info("Model Hyperparameters")
    tf.logging.info("=" * 50)
    for param, value in sorted(hparams.items()):
        tf.logging.info("%s=%s", param, value)
    tf.logging.info("=" * 50)

    # Create model
    model = model_class(source_vocab_info=source_vocab_info,
                        target_vocab_info=target_vocab_info,
                        params=hparams)
    featurizer = model.create_featurizer()

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Create input functions
    train_input_fn = training_utils.create_input_fn(
        train_data_provider,
        featurizer,
        FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries)
    eval_input_fn = training_utils.create_input_fn(dev_data_provider,
                                                   featurizer,
                                                   FLAGS.batch_size)

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        return model(features, labels, params, mode)

    estimator = tf.contrib.learn.estimator.Estimator(model_fn=model_fn,
                                                     model_dir=output_dir)

    # Create training Hooks
    model_analysis_hook = hooks.PrintModelAnalysisHook(
        filename=os.path.join(estimator.model_dir, "model_analysis.txt"))
    train_sample_hook = hooks.TrainSampleHook(
        every_n_steps=FLAGS.sample_every_n_steps)
    metadata_hook = hooks.MetadataCaptureHook(
        output_dir=os.path.join(estimator.model_dir, "metadata"), step=10)
    train_monitors = [model_analysis_hook, train_sample_hook, metadata_hook]

    experiment = tf.contrib.learn.experiment.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        min_eval_frequency=FLAGS.eval_every_n_steps,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        train_monitors=train_monitors)

    return experiment
Example #13
def test_copy_gen_model(source_path=None, target_path=None, vocab_path=None):

    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2
    input_depth = 4
    sequence_length = 10

    if vocab_path is None:
        # Create vocabulary
        vocab_list = [str(_) for _ in range(10)]
        vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
        vocab_size = len(vocab_list)
        vocab_file = test_utils.create_temporary_vocab_file(vocab_list)
        vocab_info = vocab.get_vocab_info(vocab_file.name)
        vocab_path = vocab_file.name
        tf.logging.info(vocab_file.name)
    else:
        vocab_info = vocab.get_vocab_info(vocab_path)
        vocab_list = get_vocab_list(vocab_path)

    extend_vocab = vocab_list + ["中国", "爱", "你"]

    tf.contrib.framework.get_or_create_global_step()
    source_len = sequence_length + 5
    target_len = sequence_length + 10
    source = " ".join(np.random.choice(extend_vocab, source_len))
    target = " ".join(np.random.choice(extend_vocab, target_len))

    is_tmp_file = False
    if source_path is None and target_path is None:
        is_tmp_file = True
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=[source], targets=[target])
        source_path = sources_file.name
        target_path = targets_file.name

    # Build model graph
    mode = tf.contrib.learn.ModeKeys.TRAIN
    params_ = CopyGenSeq2Seq.default_params().copy()
    params_.update({
        "vocab_source": vocab_path,
        "vocab_target": vocab_path,
    })
    model = CopyGenSeq2Seq(params=params_, mode=mode)

    tf.logging.info(source_path)
    tf.logging.info(target_path)

    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [source_path],
            "target_files": [target_path]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                              batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            fetches_ = sess.run(fetches)

    if is_tmp_file:
        sources_file.close()
        targets_file.close()

    return model, fetches_
Example #14
def create_experiment(output_dir):
    """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

    config = run_config.RunConfig(
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

    # Load vocabulary info
    source_vocab_info = vocab.get_vocab_info(FLAGS.vocab_source)
    target_vocab_info = vocab.get_vocab_info(FLAGS.vocab_target)

    # Find model class
    model_class = getattr(models, FLAGS.model)

    # Parse parameters and merge them with the defaults
    hparams = model_class.default_params()
    if FLAGS.hparams is not None and isinstance(FLAGS.hparams, str):
        hparams = HParamsParser(hparams).parse(FLAGS.hparams)
    elif isinstance(FLAGS.hparams, dict):
        hparams.update(FLAGS.hparams)

    # Print hparams
    training_utils.print_hparams(hparams)

    # On the main worker, save training options and vocabulary
    if config.is_chief:
        # Copy vocabulary to output directory
        gfile.MakeDirs(output_dir)
        source_vocab_path = os.path.join(output_dir, "vocab_source")
        gfile.Copy(FLAGS.vocab_source, source_vocab_path, overwrite=True)
        target_vocab_path = os.path.join(output_dir, "vocab_target")
        gfile.Copy(FLAGS.vocab_target, target_vocab_path, overwrite=True)
        # Save train options
        train_options = training_utils.TrainOptions(
            hparams=hparams,
            model_class=FLAGS.model,
            source_vocab_path=source_vocab_path,
            target_vocab_path=target_vocab_path)
        train_options.dump(output_dir)

    # Create model
    model = model_class(source_vocab_info=source_vocab_info,
                        target_vocab_info=target_vocab_info,
                        params=hparams)

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Create training input function
    train_input_fn = training_utils.create_input_fn(
        data_provider_fn=functools.partial(
            data_utils.make_parallel_data_provider,
            data_sources_source=FLAGS.train_source,
            data_sources_target=FLAGS.train_target,
            shuffle=True,
            num_epochs=FLAGS.train_epochs,
            delimiter=FLAGS.delimiter),
        batch_size=FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries)

    # Create eval input function
    eval_input_fn = training_utils.create_input_fn(
        data_provider_fn=functools.partial(
            data_utils.make_parallel_data_provider,
            data_sources_source=FLAGS.dev_source,
            data_sources_target=FLAGS.dev_target,
            shuffle=False,
            num_epochs=1,
            delimiter=FLAGS.delimiter),
        batch_size=FLAGS.batch_size)

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        return model(features, labels, params, mode)

    estimator = tf.contrib.learn.estimator.Estimator(model_fn=model_fn,
                                                     model_dir=output_dir,
                                                     config=config)

    train_hooks = training_utils.create_default_training_hooks(
        estimator=estimator,
        sample_frequency=FLAGS.sample_every_n_steps,
        delimiter=FLAGS.delimiter)

    eval_metrics = {
        "log_perplexity": metrics.streaming_log_perplexity(),
        "bleu": metrics.make_bleu_metric_spec(),
    }

    experiment = tf.contrib.learn.experiment.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        min_eval_frequency=FLAGS.eval_every_n_steps,
        train_steps=FLAGS.train_steps,
        eval_steps=None,
        eval_metrics=eval_metrics,
        train_monitors=train_hooks)

    return experiment
Example #15
def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False, num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")


  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  # Create metrics
  eval_metrics = {}
  for dict_ in FLAGS.metrics:
    metric = _create_from_dict(dict_, metric_specs)
    eval_metrics[metric.name] = metric

  experiment = PatchedExperiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment
Example #16
def create_estimator_and_specs(output_dir):
    session_config = tf.ConfigProto(log_device_placement=True,
                                    allow_soft_placement=True)
    session_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
    session_config.gpu_options.per_process_gpu_memory_fraction = (
        FLAGS.gpu_memory_fraction)
    config = tf.estimator.RunConfig(
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        session_config=session_config,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

    train_options = training_utils.TrainOptions(
        model_class=FLAGS.model, model_params=FLAGS.model_params)
    # On the main worker, save training options
    if config.is_chief:
        gfile.MakeDirs(output_dir)
        train_options.dump(output_dir)

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Training data input pipeline
    train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=FLAGS.input_pipeline_train,
        mode=tf.contrib.learn.ModeKeys.TRAIN)

    # Create training input function
    train_input_fn = training_utils.create_input_fn(
        pipeline=train_input_pipeline,
        batch_size=FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries,
        scope="train_input_fn")

    # Development data input pipeline
    dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=FLAGS.input_pipeline_dev,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        shuffle=False,
        num_epochs=1)

    # Create eval input function
    eval_input_fn = training_utils.create_input_fn(
        pipeline=dev_input_pipeline,
        batch_size=FLAGS.batch_size,
        allow_smaller_final_batch=True,
        scope="dev_input_fn")

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        model = _create_from_dict(
            {
                "class": train_options.model_class,
                "params": train_options.model_params
            },
            models,
            mode=mode)
        (predictions, loss, train_op) = model(features, labels, params)

        # Create metrics
        eval_metrics = {}
        for dict_ in FLAGS.metrics:
            metric = _create_from_dict(dict_, metric_specs)
            eval_metrics[metric.name] = metric(features, labels, predictions)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metrics)

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=output_dir,
                                       config=config,
                                       params=FLAGS.model_params)

    # Create hooks
    train_hooks = []
    for dict_ in FLAGS.hooks:
        hook = _create_from_dict(dict_,
                                 hooks,
                                 model_dir=estimator.model_dir,
                                 run_config=config)
        train_hooks.append(hook)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=FLAGS.train_steps,
                                        hooks=train_hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    return (estimator, train_spec, eval_spec)
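The returned triple plugs directly into tf.estimator.train_and_evaluate; a minimal sketch with a placeholder output directory:

estimator, train_spec, eval_spec = create_estimator_and_specs("/tmp/model")
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)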