Example #1
    def _testFilterByLength(self,
                            features_length,
                            labels_length,
                            maximum_features_length=None,
                            maximum_labels_length=None,
                            filtered=True):
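        """Builds a one-element dataset of (features_length, labels_length) and
        checks that ``data.filter_examples_by_length`` keeps or drops it as
        requested by ``filtered``."""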
        dataset = tf.data.Dataset.zip(
            (tf.data.Dataset.from_tensors(tf.constant(features_length)),
             tf.data.Dataset.from_tensors(tf.constant(labels_length))))
        dataset = dataset.apply(
            data.filter_examples_by_length(
                maximum_features_length=maximum_features_length,
                maximum_labels_length=maximum_labels_length,
                features_length_fn=lambda _: features_length,
                labels_length_fn=lambda _: labels_length))

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        with self.test_session() as sess:
            if filtered:
                with self.assertRaises(tf.errors.OutOfRangeError):
                    sess.run(next_element)
            else:
                sess.run(next_element)
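
A minimal sketch of how this helper might be driven, assuming the enclosing
class is a ``tf.test.TestCase`` with ``import tensorflow as tf`` and the
module's ``data`` utilities in scope; the test name and the concrete lengths
below are illustrative, not from the source:

    def testFilterExamplesByLength(self):
        # Kept: both lengths are within the maxima.
        self._testFilterByLength(
            3, 4,
            maximum_features_length=5,
            maximum_labels_length=5,
            filtered=False)
        # Dropped: the features length exceeds the maximum.
        self._testFilterByLength(
            7, 4,
            maximum_features_length=5,
            maximum_labels_length=5,
            filtered=True)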
Example #2
    def _input_fn_impl(self,
                       mode,
                       batch_size,
                       metadata,
                       features_file,
                       labels_file=None,
                       batch_type="examples",
                       batch_multiplier=1,
                       bucket_width=None,
                       single_pass=False,
                       num_threads=None,
                       sample_buffer_size=None,
                       prefetch_buffer_size=None,
                       maximum_features_length=None,
                       maximum_labels_length=None):
        """See ``input_fn``."""
        self._initialize(metadata)

        feat_dataset, feat_process_fn, feat_padded_shapes_fn = self._get_features_builder(
            features_file)

        if labels_file is None:
            dataset = feat_dataset
            # Parallel inputs must be caught in a single tuple and not
            # considered as multiple arguments.
            process_fn = lambda *arg: feat_process_fn(item_or_tuple(arg))
            padded_shapes_fn = feat_padded_shapes_fn
        else:
            labels_dataset, labels_process_fn, labels_padded_shapes_fn = (
                self._get_labels_builder(labels_file))

            dataset = tf.data.Dataset.zip((feat_dataset, labels_dataset))
            process_fn = lambda features, labels: (feat_process_fn(features),
                                                   labels_process_fn(labels))
            padded_shapes_fn = lambda: (feat_padded_shapes_fn(),
                                        labels_padded_shapes_fn())

        if mode == tf.estimator.ModeKeys.TRAIN:
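            # sample_buffer_size is expected to be set in training mode: it
            # drives both the random sharding below and dataset.shuffle().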
            dataset_size = self._get_dataset_size(features_file)
            if sample_buffer_size < dataset_size:
                # When the sample buffer size is smaller than the dataset size, shard
                # the dataset in a random order. This ensures that all parts of the
                # dataset can be seen when the evaluation frequency is high.
                dataset = dataset.apply(
                    data.random_shard(sample_buffer_size, dataset_size))
            dataset = dataset.shuffle(sample_buffer_size)
            dataset = dataset.map(process_fn,
                                  num_parallel_calls=num_threads or 4)
            dataset = dataset.apply(
                data.filter_examples_by_length(
                    maximum_features_length=maximum_features_length,
                    maximum_labels_length=maximum_labels_length,
                    features_length_fn=self._get_features_length,
                    labels_length_fn=self._get_labels_length))
            dataset = dataset.apply(
                data.batch_train_dataset(
                    batch_size,
                    batch_type=batch_type,
                    batch_multiplier=batch_multiplier,
                    bucket_width=bucket_width,
                    padded_shapes=padded_shapes_fn(),
                    features_length_fn=self._get_features_length,
                    labels_length_fn=self._get_labels_length))
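            # Drop batches whose size is not a multiple of batch_multiplier
            # (e.g. a smaller final batch).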
            dataset = dataset.apply(
                data.filter_irregular_batches(batch_multiplier))
            if not single_pass:
                dataset = dataset.repeat()
        else:
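            # Evaluation/inference: no shuffling, length filtering, or bucketing.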
            dataset = dataset.map(process_fn,
                                  num_parallel_calls=num_threads or 1)
            dataset = dataset.padded_batch(batch_size,
                                           padded_shapes=padded_shapes_fn())

        if prefetch_buffer_size:
            dataset = dataset.prefetch(prefetch_buffer_size)

        iterator = dataset.make_initializable_iterator()

        # Add the initializer to a standard collection for it to be initialized.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                             iterator.initializer)

        return iterator.get_next()
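
A minimal sketch of the public wrapper implied by the "See ``input_fn``."
docstring; the exact signature is an assumption, but ``tf.estimator`` expects
a zero-argument callable, so the heavy lifting is typically deferred behind a
lambda:

    def input_fn(self, mode, batch_size, metadata, features_file, **kwargs):
        """Hypothetical wrapper returning a callable for tf.estimator."""
        return lambda: self._input_fn_impl(
            mode, batch_size, metadata, features_file, **kwargs)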