# NOTE: This excerpt omits the test file's imports. The tests below assume
# something like the following (module paths are assumptions, not from the
# original source):
#   import numpy as np
#   import tensorflow.compat.v1 as tf
#   import utils  # The module under test, providing pad_to_batch.

def test_padding(self):
  batch_size = 10
  dataset_size = 103
  dataset = tf.data.Dataset.from_tensor_slices(np.arange(dataset_size)).map(
      lambda x: (x, x - 1, {"data_plus_one": x + 1}))
  dataset = utils.pad_to_batch(
      dataset.batch(batch_size, drop_remainder=False), batch_size)
  features = tf.data.make_one_shot_iterator(dataset).get_next()
  with self.test_session() as sess:
    # Verify all full, unpadded batches.
    for _ in range(dataset_size // batch_size):
      data, data_minus_1, data_dict = sess.run(features)
      self.assertEqual(data.shape[0], batch_size)
      self.assertAllEqual(data_minus_1, data - 1)
      self.assertAllEqual(data_dict["data_plus_one"], data + 1)

    # Get the final, padded batch (103 % 10 == 3 real examples remain).
    num_remaining_images = dataset_size % batch_size
    data, data_minus_1, data_dict = sess.run(features)

    # Verify that all of the "real" examples are properly defined.
    self.assertAllEqual(data_minus_1[:num_remaining_images],
                        data[:num_remaining_images] - 1)
    self.assertAllEqual(data_dict["data_plus_one"][:num_remaining_images],
                        data[:num_remaining_images] + 1)

    # Verify that the padding contains all zeros.
    padded_data = np.zeros(
        (batch_size - num_remaining_images,) + data.shape[1:])
    self.assertAllEqual(data[num_remaining_images:], padded_data)
    self.assertAllEqual(data_minus_1[num_remaining_images:], padded_data)
    self.assertAllEqual(data_dict["data_plus_one"][num_remaining_images:],
                        padded_data)
def __init__(self, batch_size, subset, data_dir, is_training=False):
  dataset_builder = tfds.builder("imagenet2012", data_dir=data_dir)
  dataset_builder.download_and_prepare(download_dir=data_dir)
  if subset == "train":
    dataset = dataset_builder.as_dataset(
        split=tfds.Split.TRAIN, shuffle_files=True)
  elif subset == "validation":
    dataset = dataset_builder.as_dataset(split=tfds.Split.VALIDATION)
  else:
    raise ValueError("subset %s is undefined" % subset)
  preprocess_fn = self._preprocess_fn(is_training)
  dataset = dataset.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  info = dataset_builder.info
  if is_training:
    # A 4096-example shuffle buffer is ~0.625 GB of RAM. Reduce it if
    # memory issues are encountered.
    dataset = dataset.shuffle(buffer_size=4096)
  dataset = dataset.repeat(-1 if is_training else 1)
  dataset = dataset.batch(batch_size, drop_remainder=is_training)
  if not is_training:
    # Zero-pad the final partial batch so the batch size is always fixed.
    dataset = utils.pad_to_batch(dataset, batch_size)
  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
  self.images, self.labels, self.mask = iterator.get_next()
  # Label 0 is reserved as an extra "unused" class, so zero-padded
  # examples do not collide with a real ImageNet class.
  self.num_classes = info.features["label"].num_classes + 1
  self.class_names = ["unused"] + info.features["label"].names
  self.num_examples = info.splits[subset].num_examples
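# Usage sketch (illustrative; not from the original source). The class that
# owns the __init__ above is unnamed in this excerpt, so `ImageNetDataset`
# below is a hypothetical placeholder. Because `pad_to_batch` zero-fills the
# final partial batch, padded rows presumably arrive with mask == 0 and can
# be filtered out before computing evaluation metrics:
def example_eval_batch(data_dir):
  dataset = ImageNetDataset(  # Hypothetical name for the class above.
      batch_size=64, subset="validation", data_dir=data_dir)
  with tf.compat.v1.Session() as sess:
    images, labels, mask = sess.run(
        [dataset.images, dataset.labels, dataset.mask])
  # Drop zero-padded examples before computing metrics.
  return images[mask > 0], labels[mask > 0]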
def test_shapes(self):
  batch_size = 10
  dataset_size = 11  # Deliberately not a multiple of batch_size.
  dataset = tf.data.Dataset.from_tensors(1).repeat(dataset_size)
  dataset = utils.pad_to_batch(
      dataset.batch(batch_size, drop_remainder=False), batch_size)
  features = tf.data.make_one_shot_iterator(dataset).get_next()[0]
  self.assertEqual(features.shape.as_list(), [batch_size])
def test_called_with_missing_dynamic_dimensions(self):
  tensor_1 = tf.placeholder(tf.int32, (None,))
  tensor_2 = tf.placeholder(tf.int32, (None,))
  bad_tf_dataset = tf.data.Dataset.from_tensors((tensor_1, tensor_2))
  bad_tf_dataset = utils.pad_to_batch(bad_tf_dataset, 1)
  iterator = bad_tf_dataset.make_initializable_iterator()
  with self.test_session() as sess:
    # The batch dimensions (1 vs. 2) are unknown statically, so the
    # mismatch can only be detected when the iterator actually runs.
    sess.run(iterator.initializer,
             feed_dict={tensor_1: [1], tensor_2: [1, 2]})
    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError,
        ".*Batch size of dataset tensors .* do not match.*"):
      sess.run(iterator.get_next())
def test_called_with_missing_static_dimensions(self):
  bad_tf_dataset = tf.data.Dataset.from_tensors((np.ones(1), np.ones(2)))
  with self.assertRaisesRegex(
      ValueError, "Batch size of dataset tensors does not match.*"):
    bad_tf_dataset = utils.pad_to_batch(bad_tf_dataset, 1)
def test_called_with_scalar(self):
  bad_tf_dataset = tf.data.Dataset.from_tensors(0)
  with self.assertRaisesRegex(ValueError, "Tensor .* is a scalar."):
    bad_tf_dataset = utils.pad_to_batch(bad_tf_dataset, 1)
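# The tests above pin down the contract of `utils.pad_to_batch`, whose
# implementation is not included in this excerpt. The sketch below is a
# minimal reconstruction consistent with those tests (zero-padding along
# dimension 0, a fixed static batch dimension, and the scalar /
# mismatched-batch-size errors); it is an assumption, not the original code.
def pad_to_batch(dataset, batch_size):
  """Zero-pads the leading dimension of every tensor up to `batch_size`."""

  def _pad_fn(*args):
    flat = tf.nest.flatten(args)
    # Reject scalars and statically known batch-size mismatches up front.
    static_batch_size = None
    for t in flat:
      if t.shape.ndims == 0:
        raise ValueError("Tensor %s is a scalar." % t.name)
      static = t.shape.as_list()[0]
      if static is not None:
        if static_batch_size is not None and static != static_batch_size:
          raise ValueError("Batch size of dataset tensors does not match.")
        static_batch_size = static
    # Batch sizes known only at run time must be checked dynamically.
    dynamic = [tf.shape(t)[0] for t in flat]
    checks = [
        tf.debugging.assert_equal(
            dynamic[0], dim,
            message="Batch size of dataset tensors %d and 0 do not match."
            % i)
        for i, dim in enumerate(dynamic)
    ]
    with tf.control_dependencies(checks):
      padded = []
      for t, dim in zip(flat, dynamic):
        padded_t = tf.pad(
            t, [[0, batch_size - dim]] + [[0, 0]] * (t.shape.ndims - 1))
        # Record the now-fixed static batch dimension (see test_shapes).
        padded_t.set_shape([batch_size] + t.shape.as_list()[1:])
        padded.append(padded_t)
    return tf.nest.pack_sequence_as(args, padded)

  return dataset.map(_pad_fn)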