Example #1
    def testIrregularBatches(self):
        batch_size = 12
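        # 2 * batch_size - 1 = 23 elements: one full batch of 12 plus a ragged batch of 11.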
        dataset = tf.data.Dataset.range(batch_size * 2 - 1)
        dataset = dataset.map(lambda x: {"x": x, "y": x + 1})
        dataset = dataset.batch(batch_size)
        dataset = dataset.apply(data.filter_irregular_batches(batch_size))

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        with self.test_session() as sess:
            single_element = sess.run(next_element)
            self.assertEqual(batch_size, single_element["x"].size)
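            # The ragged batch of 11 was filtered out, so the iterator is already exhausted.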
            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(next_element)
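
Both snippets call into a helper module named data whose implementation is not shown here. As a rough guide, below is a minimal sketch of what data.filter_irregular_batches could look like, assuming it simply drops batches whose leading dimension is not a multiple of the given value (TF 1.x API; an illustration, not the library's actual code):

    import tensorflow as tf

    def filter_irregular_batches(multiple):
        """Returns a tf.data transformation that drops batches whose size is
        not a multiple of ``multiple`` (sketch, assumed behavior)."""
        if multiple == 1:
            return lambda dataset: dataset  # Every batch size is a multiple of 1.

        def _predicate(*batch):
            # Elements may be tensors, tuples, or dicts; inspect the first flat tensor.
            first = tf.contrib.framework.nest.flatten(batch)[0]
            batch_size = tf.shape(first)[0]
            return tf.equal(batch_size % multiple, 0)

        return lambda dataset: dataset.filter(_predicate)

Under this assumed behavior, the 23-element dataset above is batched into 12 and 11 elements, and only the batch of 12 survives the filter, which is exactly what the test asserts.
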
Example #2
    def _input_fn_impl(self,
                       mode,
                       batch_size,
                       metadata,
                       features_file,
                       labels_file=None,
                       batch_type="examples",
                       batch_multiplier=1,
                       bucket_width=None,
                       single_pass=False,
                       num_threads=None,
                       sample_buffer_size=None,
                       prefetch_buffer_size=None,
                       maximum_features_length=None,
                       maximum_labels_length=None):
        """See ``input_fn``."""
        self._initialize(metadata)

        feat_dataset, feat_process_fn, feat_padded_shapes_fn = self._get_features_builder(
            features_file)

        if labels_file is None:
            dataset = feat_dataset
            # Parallel inputs must be caught in a single tuple, not treated as multiple arguments.
            process_fn = lambda *arg: feat_process_fn(item_or_tuple(arg))
            padded_shapes_fn = feat_padded_shapes_fn
        else:
            labels_dataset, labels_process_fn, labels_padded_shapes_fn = (
                self._get_labels_builder(labels_file))

            dataset = tf.data.Dataset.zip((feat_dataset, labels_dataset))
            process_fn = lambda features, labels: (feat_process_fn(features),
                                                   labels_process_fn(labels))
            padded_shapes_fn = lambda: (feat_padded_shapes_fn(),
                                        labels_padded_shapes_fn())

        if mode == tf.estimator.ModeKeys.TRAIN:
            dataset_size = self._get_dataset_size(features_file)
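            # sample_buffer_size must be set in TRAIN mode: it is compared against
            # the dataset size here and passed to dataset.shuffle below.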
            if sample_buffer_size < dataset_size:
                # When the sample buffer size is smaller than the dataset size, shard
                # the dataset in a random order. This ensures that all parts of the
                # dataset can be seen when the evaluation frequency is high.
                dataset = dataset.apply(
                    data.random_shard(sample_buffer_size, dataset_size))
            dataset = dataset.shuffle(sample_buffer_size)
            dataset = dataset.map(process_fn,
                                  num_parallel_calls=num_threads or 4)
            dataset = dataset.apply(
                data.filter_examples_by_length(
                    maximum_features_length=maximum_features_length,
                    maximum_labels_length=maximum_labels_length,
                    features_length_fn=self._get_features_length,
                    labels_length_fn=self._get_labels_length))
            dataset = dataset.apply(
                data.batch_train_dataset(
                    batch_size,
                    batch_type=batch_type,
                    batch_multiplier=batch_multiplier,
                    bucket_width=bucket_width,
                    padded_shapes=padded_shapes_fn(),
                    features_length_fn=self._get_features_length,
                    labels_length_fn=self._get_labels_length))
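            # Drop batches whose size is not a multiple of batch_multiplier
            # (e.g. so they can be split evenly across replicas).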
            dataset = dataset.apply(
                data.filter_irregular_batches(batch_multiplier))
            if not single_pass:
                dataset = dataset.repeat()
        else:
            dataset = dataset.map(process_fn,
                                  num_parallel_calls=num_threads or 1)
            dataset = dataset.padded_batch(batch_size,
                                           padded_shapes=padded_shapes_fn())

        if prefetch_buffer_size:
            dataset = dataset.prefetch(prefetch_buffer_size)

        iterator = dataset.make_initializable_iterator()

        # Add the initializer to a standard collection for it to be initialized.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                             iterator.initializer)

        return iterator.get_next()
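
The random sharding applied in the training branch also comes from the unshown data module. Here is a minimal sketch of what data.random_shard could do, assuming it visits fixed-size contiguous shards of the dataset in a shuffled order (TF 1.x; an illustration, not the library's actual code):

    import tensorflow as tf

    def random_shard(shard_size, dataset_size):
        """Returns a tf.data transformation that reads contiguous shards of
        ``shard_size`` elements in a random order (sketch, assumed behavior)."""
        num_shards = -(-dataset_size // shard_size)  # Ceiling division.

        def _random_shard(dataset):
            shards = tf.data.Dataset.range(num_shards)
            shards = shards.shuffle(num_shards)
            # Visit each shard once, in the shuffled order.
            return shards.flat_map(
                lambda offset: dataset.skip(offset * shard_size).take(shard_size))

        return _random_shard

Combined with dataset.shuffle(sample_buffer_size), this lets every region of a large dataset be seen even when the shuffle buffer is much smaller than the dataset, matching the comment in the training branch above.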