def _testFilterByLength(self,
                        features_length,
                        labels_length,
                        maximum_features_length=None,
                        maximum_labels_length=None,
                        filtered=True):
  dataset = tf.data.Dataset.zip(
      (tf.data.Dataset.from_tensors(tf.constant(features_length)),
       tf.data.Dataset.from_tensors(tf.constant(labels_length))))
  dataset = dataset.apply(data.filter_examples_by_length(
      maximum_features_length=maximum_features_length,
      maximum_labels_length=maximum_labels_length,
      features_length_fn=lambda _: features_length,
      labels_length_fn=lambda _: labels_length))
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.test_session() as sess:
    if filtered:
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_element)
    else:
      sess.run(next_element)
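# A minimal usage sketch for the helper above. The concrete test cases and
# lengths are assumptions for illustration, not taken from the original test
# suite: an example is expected to be filtered out (raising
# tf.errors.OutOfRangeError) when one of its lengths exceeds the
# corresponding maximum.
def testFilterByLengthUsage(self):
  # Features length 6 exceeds the maximum of 5: the example is dropped.
  self._testFilterByLength(6, 4, maximum_features_length=5, filtered=True)
  # Both lengths are within their maximums: the example is kept.
  self._testFilterByLength(
      6, 4, maximum_features_length=6, maximum_labels_length=4, filtered=False)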
def _input_fn_impl(self,
                   mode,
                   batch_size,
                   metadata,
                   features_file,
                   labels_file=None,
                   batch_type="examples",
                   batch_multiplier=1,
                   bucket_width=None,
                   single_pass=False,
                   num_threads=None,
                   sample_buffer_size=None,
                   prefetch_buffer_size=None,
                   maximum_features_length=None,
                   maximum_labels_length=None):
  """See ``input_fn``."""
  self._initialize(metadata)

  feat_dataset, feat_process_fn, feat_padded_shapes_fn = self._get_features_builder(
      features_file)

  if labels_file is None:
    dataset = feat_dataset
    # Parallel inputs must be caught in a single tuple and not considered as
    # multiple arguments.
    process_fn = lambda *arg: feat_process_fn(item_or_tuple(arg))
    padded_shapes_fn = feat_padded_shapes_fn
  else:
    labels_dataset, labels_process_fn, labels_padded_shapes_fn = (
        self._get_labels_builder(labels_file))
    dataset = tf.data.Dataset.zip((feat_dataset, labels_dataset))
    process_fn = lambda features, labels: (
        feat_process_fn(features), labels_process_fn(labels))
    padded_shapes_fn = lambda: (
        feat_padded_shapes_fn(), labels_padded_shapes_fn())

  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset_size = self._get_dataset_size(features_file)
    if sample_buffer_size < dataset_size:
      # When the sample buffer size is smaller than the dataset size, shard
      # the dataset in a random order. This ensures that all parts of the
      # dataset can be seen when the evaluation frequency is high.
      dataset = dataset.apply(data.random_shard(sample_buffer_size, dataset_size))
    dataset = dataset.shuffle(sample_buffer_size)
    dataset = dataset.map(process_fn, num_parallel_calls=num_threads or 4)
    # Drop examples that exceed the maximum lengths.
    dataset = dataset.apply(data.filter_examples_by_length(
        maximum_features_length=maximum_features_length,
        maximum_labels_length=maximum_labels_length,
        features_length_fn=self._get_features_length,
        labels_length_fn=self._get_labels_length))
    dataset = dataset.apply(data.batch_train_dataset(
        batch_size,
        batch_type=batch_type,
        batch_multiplier=batch_multiplier,
        bucket_width=bucket_width,
        padded_shapes=padded_shapes_fn(),
        features_length_fn=self._get_features_length,
        labels_length_fn=self._get_labels_length))
    dataset = dataset.apply(data.filter_irregular_batches(batch_multiplier))
    if not single_pass:
      dataset = dataset.repeat()
  else:
    dataset = dataset.map(process_fn, num_parallel_calls=num_threads or 1)
    dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes_fn())

  if prefetch_buffer_size:
    dataset = dataset.prefetch(prefetch_buffer_size)

  iterator = dataset.make_initializable_iterator()

  # Add the initializer to a standard collection for it to be initialized.
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

  return iterator.get_next()
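# The docstring above points to a public ``input_fn`` entry point. A minimal
# sketch of such a wrapper (its exact signature is an assumption, not taken
# from the source): it defers pipeline construction into a zero-argument
# callable, as expected by tf.estimator.Estimator.train(). Note that
# ``sample_buffer_size`` must be set in TRAIN mode, since it is compared
# against the dataset size before shuffling.
def input_fn(self, mode, batch_size, metadata, features_file, **kwargs):
  """Returns a callable that builds the input pipeline in the Estimator graph."""
  return lambda: self._input_fn_impl(
      mode, batch_size, metadata, features_file, **kwargs)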