Example #1
0
    def make_inference_dataset(self,
                               features_file,
                               batch_size,
                               bucket_width=None,
                               num_threads=1,
                               prefetch_buffer_size=None):
        """Creates a ``tf.data.Dataset`` for running inference.

        For training and evaluation datasets, see
        :class:`opennmt.inputters.inputter.ExampleInputter`.

        Args:
          features_file: The file with the inference inputs.
          batch_size: The size of each batch.
          bucket_width: Width of the length buckets used to pick batch
            candidates (for efficiency); ``None`` disables length-based
            batch formation.
          num_threads: Number of elements processed in parallel.
          prefetch_buffer_size: Number of batches to prefetch asynchronously.
            ``None`` selects an automatically tuned value on TensorFlow 1.8+
            and 1 on older versions.

        Returns:
          A ``tf.data.Dataset``.
        """

        def _process(*args):
            # Parallel inputs are packed into a tuple; unwrap singletons.
            return self.make_features(item_or_tuple(args), training=False)

        return inference_pipeline(
            self.make_dataset(features_file, training=False),
            batch_size,
            process_fn=_process,
            num_threads=num_threads,
            prefetch_buffer_size=prefetch_buffer_size,
            bucket_width=bucket_width,
            length_fn=self.get_length)
Example #2
0
    def make_evaluation_dataset(self,
                                features_file,
                                labels_file,
                                batch_size,
                                num_threads=1,
                                prefetch_buffer_size=None):
        """Creates a ``tf.data.Dataset`` for evaluation.

        Args:
          features_file: The source file of the evaluation data.
          labels_file: The target file of the evaluation data.
          batch_size: The size of each batch.
          num_threads: Number of elements processed in parallel.
          prefetch_buffer_size: Number of batches to prefetch asynchronously.
            ``None`` selects an automatically tuned value on TensorFlow 1.8+
            and 1 on older versions.

        Returns:
          A ``tf.data.Dataset``.
        """

        def _process(*args):
            # Forward the (features, labels) pair as a single tuple.
            return self.make_features(args, training=False)

        parallel_dataset = self.make_dataset(
            [features_file, labels_file], training=False)
        return inference_pipeline(
            parallel_dataset,
            batch_size,
            process_fn=_process,
            num_threads=num_threads,
            prefetch_buffer_size=prefetch_buffer_size)
Example #3
0
 def make_evaluation_dataset(self,
                             features_file,
                             labels_file,
                             batch_size,
                             num_threads=1,
                             prefetch_buffer_size=None):
   """See :meth:`opennmt.inputters.inputter.ExampleInputter.make_evaluation_dataset`."""
   # This inputter evaluates on features only; the labels file is ignored.
   _ = labels_file

   def _process(element):
     return self._generate_example(element, training=False)

   return data.inference_pipeline(
       self.make_dataset(features_file, training=False),
       batch_size,
       process_fn=_process,
       num_threads=num_threads,
       prefetch_buffer_size=prefetch_buffer_size)
Example #4
0
 def make_inference_dataset_ae(self,
                               features_file,
                               batch_size,
                               bucket_width=None,
                               num_threads=1,
                               prefetch_buffer_size=None):
   """Builds a ``tf.data.Dataset`` for inference (autoencoder variant).

   Args:
     features_file: The test file, read line by line.
     batch_size: The batch size to use.
     bucket_width: The width of the length buckets to select batch candidates
       from (for efficiency). Set ``None`` to not constrain batch formation.
     num_threads: The number of elements processed in parallel.
     prefetch_buffer_size: The number of batches to prefetch asynchronously.
       If ``None``, use an automatically tuned value.

   Returns:
     A ``tf.data.Dataset``.
   """
   # Lazily load both vocabularies before building the pipeline; the
   # processing function presumably depends on them — TODO confirm.
   if self.features_inputter.vocabulary is None:
     self.features_inputter.vocabulary = self.features_inputter.vocabulary_lookup()
   if self.labels_inputter.vocabulary is None:
     self.labels_inputter.vocabulary = self.labels_inputter.vocabulary_lookup()
   dataset = tf.data.TextLineDataset(features_file)
   dataset = inference_pipeline(
       dataset,
       batch_size,
       # Bound method directly: no need for a lambda wrapper.
       process_fn=self.make_inference_features,
       num_threads=num_threads,
       prefetch_buffer_size=prefetch_buffer_size,
       bucket_width=bucket_width,
       length_fn=lambda features: features.get("length"))
   return dataset
Example #5
0
 def _fn():
     """Builds the dataset pipeline and returns the next batch of elements."""
     self._initialize(metadata)
     ds = inputter.make_dataset(data_file, training=training)
     if training:
         # With token-based batches and float16, pad the batch size to a
         # multiple of 8 — presumably for hardware efficiency; TODO confirm.
         size_multiple = (
             8 if batch_type == "tokens" and self.dtype == tf.float16 else 1)
         ds = data.training_pipeline(
             ds,
             batch_size,
             batch_type=batch_type,
             batch_multiplier=batch_multiplier,
             bucket_width=bucket_width,
             single_pass=single_pass,
             process_fn=process_fn,
             num_threads=num_threads,
             shuffle_buffer_size=sample_buffer_size,
             prefetch_buffer_size=prefetch_buffer_size,
             dataset_size=self.features_inputter.get_dataset_size(
                 features_file),
             maximum_features_length=maximum_features_length,
             maximum_labels_length=maximum_labels_length,
             features_length_fn=self.features_inputter.get_length,
             labels_length_fn=self.labels_inputter.get_length,
             batch_size_multiple=size_multiple,
             num_shards=num_shards,
             shard_index=shard_index)
     else:
         ds = data.inference_pipeline(
             ds,
             batch_size,
             process_fn=process_fn,
             num_threads=num_threads,
             prefetch_buffer_size=prefetch_buffer_size,
             bucket_width=bucket_width,
             length_fn=self.features_inputter.get_length)
     iterator = ds.make_initializable_iterator()
     # Register the initializer in a standard collection so it runs with the
     # other table initializers.
     tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                          iterator.initializer)
     return iterator.get_next()
Example #6
0
  def testReorderInferDataset(self):
    """Checks that length-bucketed inference batches carry original indices."""
    dataset = tf.data.Dataset.from_tensor_slices([8, 2, 5, 6, 7, 1, 3, 9])
    dataset = dataset.map(lambda v: {"length": v})
    dataset = data.inference_pipeline(
        dataset, 3, bucket_width=3, length_fn=lambda e: e["length"])
    get_next = dataset.make_one_shot_iterator().get_next()

    def _assert_batch(batch, lengths, indices):
      self.assertAllEqual(batch["length"], lengths)
      self.assertAllEqual(batch["index"], indices)

    with self.test_session() as sess:
      batches = []
      # Drain the dataset until exhaustion.
      while True:
        try:
          batches.append(sess.run(get_next))
        except tf.errors.OutOfRangeError:
          break
      self.assertEqual(len(batches), 4)
      _assert_batch(batches[0], [8, 6, 7], [0, 3, 4])
      _assert_batch(batches[1], [2, 1], [1, 5])
      _assert_batch(batches[2], [5, 3], [2, 6])
      _assert_batch(batches[3], [9], [7])
Example #7
0
    def _input_fn_impl(self,
                       mode,
                       batch_size,
                       metadata,
                       features_file,
                       labels_file=None,
                       batch_type="examples",
                       batch_multiplier=1,
                       bucket_width=None,
                       single_pass=False,
                       num_threads=None,
                       sample_buffer_size=None,
                       prefetch_buffer_size=None,
                       maximum_features_length=None,
                       maximum_labels_length=None):
        """See ``input_fn``."""
        self._initialize(metadata)

        feat_dataset, feat_process_fn = self._get_features_builder(
            features_file)

        if labels_file is None:
            dataset = feat_dataset

            # Parallel inputs arrive packed in a single tuple rather than as
            # multiple arguments.
            def process_fn(*args):
                return feat_process_fn(item_or_tuple(args))
        else:
            labels_dataset, labels_process_fn = self._get_labels_builder(
                labels_file)
            dataset = tf.data.Dataset.zip((feat_dataset, labels_dataset))

            def process_fn(features, labels):
                return (feat_process_fn(features), labels_process_fn(labels))

            dataset, process_fn = self._augment_parallel_dataset(
                dataset, process_fn, mode=mode)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dataset = data.training_pipeline(
                dataset,
                batch_size,
                batch_type=batch_type,
                batch_multiplier=batch_multiplier,
                bucket_width=bucket_width,
                single_pass=single_pass,
                process_fn=process_fn,
                num_threads=num_threads,
                shuffle_buffer_size=sample_buffer_size,
                prefetch_buffer_size=prefetch_buffer_size,
                dataset_size=self._get_dataset_size(features_file),
                maximum_features_length=maximum_features_length,
                maximum_labels_length=maximum_labels_length,
                features_length_fn=self._get_features_length,
                labels_length_fn=self._get_labels_length)
        else:
            dataset = data.inference_pipeline(
                dataset,
                batch_size,
                process_fn=process_fn,
                num_threads=num_threads,
                prefetch_buffer_size=prefetch_buffer_size)

        iterator = dataset.make_initializable_iterator()
        # Register the initializer in a standard collection so it runs with
        # the other table initializers.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                             iterator.initializer)
        return iterator.get_next()