Example #1
    def _testBatchTrainDataset(self, check_fn, batch_size, **kwargs):
        """Builds a batched parallel training dataset of random integer lengths and runs ``check_fn`` on it."""
        num_examples = 1000
        # Labels are noisy copies of the features so both sides stay correlated.
        features = tf.random_normal([num_examples], mean=12, stddev=6, seed=42)
        labels_diff = tf.random_normal([num_examples],
                                       mean=0,
                                       stddev=3,
                                       seed=42)
        labels = features + labels_diff

        # Convert to positive integers so the values themselves can act as sequence lengths.
        features = tf.maximum(tf.to_int32(1), tf.to_int32(features))
        labels = tf.maximum(tf.to_int32(1), tf.to_int32(labels))

        dataset = tf.data.Dataset.zip(
            (tf.data.Dataset.from_tensor_slices(features),
             tf.data.Dataset.from_tensor_slices(labels)))
        dataset = dataset.apply(
            data.batch_parallel_dataset(batch_size,
                                        features_length_fn=lambda x: x,
                                        labels_length_fn=lambda x: x,
                                        **kwargs))

        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(iterator.initializer)
            check_fn(sess, next_element)
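
For reference, a check_fn passed to this helper could drain the iterator and assert simple batch invariants. The sketch below is illustrative only (the method name and assertions are assumptions, not part of the original test suite) and assumes the batched dataset yields (features, labels) tuples.

    def _checkMaxBatchSize(self, batch_size):
        """Returns a check_fn asserting that no batch exceeds ``batch_size`` examples."""
        def _check_fn(sess, next_element):
            while True:
                try:
                    features, labels = sess.run(next_element)
                    # Both sides of the parallel dataset are batched along the first axis.
                    self.assertLessEqual(features.shape[0], batch_size)
                    self.assertEqual(features.shape[0], labels.shape[0])
                except tf.errors.OutOfRangeError:
                    break
        return _check_fn

    # e.g. self._testBatchTrainDataset(self._checkMaxBatchSize(64), 64)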
Example #2
def input_fn_impl(text, model, batch_size, metadata):
    """
    Initializes the model with the metadata, creates a single-tensor dataset
    from the input text and creates a process function that convert the input into
    a sequence. Then creates an iterator that creates predictions for all input tensors.
    Will only contain one in this case because there's only one string tensor as input,
    but can also be used for generating predictions for a whole set of inputs, i.e.
    a file listing multiple inputs to generate predictions for.

    Args:
        text: The input text
        model: The trained model
        batch_size: The maximum number of inputs to generate predictions for in one iteration call.
        metadata: The config dict containg the paths to the vocabulary files.

    Returns:
        The first prediction in the iterator, i.e. the answer sequence to the
        input text.
    """
    model._initialize(metadata)

    dataset = tf.data.Dataset.from_tensor_slices([text])
    # Parallel inputs must be caught in a single tuple and not treated as multiple arguments.
    process_fn = lambda *arg: model.source_inputter.process(item_or_tuple(arg))

    dataset = dataset.map(
        process_fn,
        num_parallel_calls=1)
    dataset = dataset.apply(data.batch_parallel_dataset(batch_size))

    iterator = dataset.make_initializable_iterator()

    # Add the initializer to a standard collection for it to be initialized.
    # See https://www.tensorflow.org/api_docs/python/tf/Graph for more information
    # about tensorflow graphs.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    return iterator.get_next()
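
A minimal usage sketch for this function, assuming a trained model object and a metadata dict are available from the surrounding application; it also shows why the iterator initializer is registered in the TABLE_INITIALIZERS collection:

# Hypothetical usage; `model` and `metadata` come from the surrounding
# application and are not defined here.
features = input_fn_impl("Hello world !", model, batch_size=1, metadata=metadata)

with tf.Session() as sess:
    # Running the table initializers also runs the iterator initializer that
    # was added to the TABLE_INITIALIZERS collection above.
    sess.run(tf.tables_initializer())
    print(sess.run(features))  # Processed features for the single input text.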
Example #3
    def _input_fn_impl(self,
                       mode,
                       batch_size,
                       metadata,
                       features_file,
                       labels_file=None,
                       batch_type="examples",
                       batch_multiplier=1,
                       bucket_width=None,
                       single_pass=False,
                       num_threads=None,
                       sample_buffer_size=None,
                       prefetch_buffer_size=None,
                       maximum_features_length=None,
                       maximum_labels_length=None):
        """See ``input_fn``."""
        self._initialize(metadata)

        feat_dataset, feat_process_fn = self._get_features_builder(
            features_file)

        if labels_file is None:
            dataset = feat_dataset
            # Parallel inputs must be caught in a single tuple and not treated as multiple arguments.
            process_fn = lambda *arg: feat_process_fn(item_or_tuple(arg))
        else:
            labels_dataset, labels_process_fn = self._get_labels_builder(
                labels_file)

            dataset = tf.data.Dataset.zip((feat_dataset, labels_dataset))
            process_fn = lambda features, labels: (feat_process_fn(features),
                                                   labels_process_fn(labels))

        if mode == tf.estimator.ModeKeys.TRAIN:
            dataset_size = self._get_dataset_size(features_file)
            if sample_buffer_size is not None and sample_buffer_size != 0:
                if sample_buffer_size < 0:
                    sample_buffer_size = dataset_size
                elif sample_buffer_size < dataset_size:
                    # When the sample buffer size is smaller than the dataset size, shard
                    # the dataset in a random order. This ensures that all parts of the
                    # dataset can be seen when the evaluation frequency is high.
                    dataset = dataset.apply(
                        data.random_shard(sample_buffer_size, dataset_size))
                dataset = dataset.shuffle(sample_buffer_size)
            dataset = dataset.map(process_fn,
                                  num_parallel_calls=num_threads or 4)
            dataset = dataset.apply(
                data.filter_examples_by_length(
                    maximum_features_length=maximum_features_length,
                    maximum_labels_length=maximum_labels_length,
                    features_length_fn=self._get_features_length,
                    labels_length_fn=self._get_labels_length))
            dataset = dataset.apply(
                data.batch_parallel_dataset(
                    batch_size,
                    batch_type=batch_type,
                    batch_multiplier=batch_multiplier,
                    bucket_width=bucket_width,
                    features_length_fn=self._get_features_length,
                    labels_length_fn=self._get_labels_length))
            dataset = dataset.apply(
                data.filter_irregular_batches(batch_multiplier))
            if not single_pass:
                dataset = dataset.repeat()
        else:
            dataset = dataset.map(process_fn,
                                  num_parallel_calls=num_threads or 1)
            dataset = dataset.apply(data.batch_parallel_dataset(batch_size))

        if prefetch_buffer_size:
            dataset = dataset.prefetch(prefetch_buffer_size)

        iterator = dataset.make_initializable_iterator()

        # Add the initializer to a standard collection for it to be initialized.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                             iterator.initializer)

        return iterator.get_next()
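
Because tf.estimator expects a zero-argument input_fn, this method is typically bound to its arguments before being handed to the estimator. The sketch below is an assumed usage pattern only: the file names, metadata, model_fn, and hyperparameter values are placeholders, not part of the code above.

# Hypothetical wrapper; file paths, `metadata`, and `model_fn` are placeholders.
train_input_fn = lambda: model._input_fn_impl(
    tf.estimator.ModeKeys.TRAIN,
    batch_size=3072,
    metadata=metadata,
    features_file="train.src",
    labels_file="train.tgt",
    batch_type="tokens",
    sample_buffer_size=500000)

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="checkpoints/")
estimator.train(train_input_fn, max_steps=100000)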