Code Example #1
  def test_padding(self):
    batch_size = 10
    dataset_size = 103
    dataset = tf.data.Dataset.from_tensor_slices(np.arange(dataset_size)).map(
        lambda x: (x, x - 1, {"data_plus_one": x + 1}))
    dataset = utils.pad_to_batch(
        dataset.batch(batch_size, drop_remainder=False), batch_size)
    features = tf.data.make_one_shot_iterator(dataset).get_next()
    with self.test_session() as sess:
      # Verify all unpadded batches.
      for _ in range(dataset_size // batch_size):
        data, data_minus_1, data_dict = sess.run(features)
        self.assertEqual(data.shape[0], batch_size)
        self.assertAllEqual(data_minus_1, data - 1)
        self.assertAllEqual(data_dict["data_plus_one"], data + 1)

      # Get the final, padded batch.
      num_remaining_images = dataset_size % batch_size
      data, data_minus_1, data_dict = sess.run(features)

      # Verify that all of the "real" examples are properly defined.
      self.assertAllEqual(data_minus_1[:num_remaining_images],
                          data[:num_remaining_images] - 1)
      self.assertAllEqual(data_dict["data_plus_one"][:num_remaining_images],
                          data[:num_remaining_images] + 1)

      # Verify that the padding contains all zeros.
      padded_data = np.zeros(
          (batch_size - num_remaining_images,) + data.shape[1:])
      self.assertAllEqual(data[num_remaining_images:], padded_data)
      self.assertAllEqual(data_minus_1[num_remaining_images:], padded_data)
      self.assertAllEqual(data_dict["data_plus_one"][num_remaining_images:],
                          padded_data)
Code Example #2
    def __init__(self, batch_size, subset, data_dir, is_training=False):
        dataset_builder = tfds.builder("imagenet2012", data_dir=data_dir)
        dataset_builder.download_and_prepare(download_dir=data_dir)
        if subset == "train":
            dataset = dataset_builder.as_dataset(split=tfds.Split.TRAIN,
                                                 shuffle_files=True)
        elif subset == "validation":
            dataset = dataset_builder.as_dataset(split=tfds.Split.VALIDATION)
        else:
            raise ValueError("subset %s is undefined" % subset)
        preprocess_fn = self._preprocess_fn(is_training)
        dataset = dataset.map(preprocess_fn,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        info = dataset_builder.info
        if is_training:
            # 4096 is ~0.625 GB of RAM. Reduce if memory issues encountered.
            dataset = dataset.shuffle(buffer_size=4096)
        dataset = dataset.repeat(-1 if is_training else 1)
        dataset = dataset.batch(batch_size, drop_remainder=is_training)

        if not is_training:
            # Pad the remainder of the last batch to make batch size fixed.
            dataset = utils.pad_to_batch(dataset, batch_size)

        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        self.images, self.labels, self.mask = iterator.get_next()
        self.num_classes = info.features["label"].num_classes + 1
        self.class_names = ["unused"] + info.features["label"].names
        self.num_examples = info.splits[subset].num_examples
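Two details of this pipeline follow from the zero-padding contract. First, a per-example mask that the (unshown) preprocess_fn presumably emits as 1.0 becomes 0.0 on padded rows. Second, zero-padded labels land in class 0, so class 0 is reserved as "unused" and num_classes is bumped by one. A hedged sketch of how the mask would then weight an evaluation metric; masked_accuracy is a hypothetical helper, not part of the pipeline above:

  import tensorflow.compat.v1 as tf

  def masked_accuracy(logits, labels, mask):
    """Accuracy over real examples only; assumes integer labels and a
    float mask that is 1.0 for real rows and 0.0 for padded ones."""
    predictions = tf.argmax(logits, axis=1, output_type=tf.int32)
    correct = tf.cast(tf.equal(predictions, tf.cast(labels, tf.int32)),
                      tf.float32)
    # Padded rows carry mask == 0 and drop out of both sums.
    return tf.reduce_sum(correct * mask) / tf.reduce_sum(mask)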
Code Example #3
  def test_shapes(self):
    batch_size = 10
    dataset_size = 11  # Deliberately not a multiple of batch_size.
    dataset = tf.data.Dataset.from_tensors(1).repeat(dataset_size)
    dataset = utils.pad_to_batch(
        dataset.batch(batch_size, drop_remainder=False), batch_size)
    # The padded dataset yields a tuple; unwrap its single component.
    features = tf.data.make_one_shot_iterator(dataset).get_next()[0]
    self.assertEqual(features.shape.as_list(), [batch_size])
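The point of this test is shape inference rather than values: with drop_remainder=False the batch dimension is statically unknown, and pad_to_batch restores a static batch_size, which consumers that require a fixed batch size can rely on. The same guarantee can be observed on the dataset directly (TF1-style API as in the tests; the commented shapes are what the padding contract implies):

  dataset = tf.data.Dataset.from_tensors(1).repeat(11)
  batched = dataset.batch(10, drop_remainder=False)
  print(tf.compat.v1.data.get_output_shapes(batched))  # batch dim unknown: (?,)
  padded = utils.pad_to_batch(batched, 10)
  print(tf.compat.v1.data.get_output_shapes(padded))   # statically (10,), in a 1-tuple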
Code Example #4
  def test_called_with_missing_dynamic_dimensions(self):
    # Both placeholders have an unknown batch dimension, so the mismatch
    # fed below is only detectable when the session runs.
    tensor_1 = tf.placeholder(tf.int32, (None,))
    tensor_2 = tf.placeholder(tf.int32, (None,))
    bad_tf_dataset = tf.data.Dataset.from_tensors((tensor_1, tensor_2))
    bad_tf_dataset = utils.pad_to_batch(bad_tf_dataset, 1)
    iterator = bad_tf_dataset.make_initializable_iterator()
    with self.test_session() as sess:
      sess.run(iterator.initializer,
               feed_dict={tensor_1: [1], tensor_2: [1, 2]})
      with self.assertRaisesRegex(
          tf.errors.InvalidArgumentError,
          ".*Batch size of dataset tensors .* do not match.*"):
        sess.run(iterator.get_next())
Code Example #5
  def test_called_with_missing_static_dimensions(self):
    # Shapes (1,) and (2,) are fully static, so pad_to_batch can reject
    # the mismatch at graph-construction time.
    bad_tf_dataset = tf.data.Dataset.from_tensors((np.ones(1), np.ones(2)))
    with self.assertRaisesRegex(
        ValueError, "Batch size of dataset tensors does not match.*"):
      bad_tf_dataset = utils.pad_to_batch(bad_tf_dataset, 1)
Code Example #6
  def test_called_with_scalar(self):
    # A scalar has no leading batch dimension to pad.
    bad_tf_dataset = tf.data.Dataset.from_tensors(0)
    with self.assertRaisesRegex(ValueError, "Tensor .* is a scalar."):
      bad_tf_dataset = utils.pad_to_batch(bad_tf_dataset, 1)
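Examples #4-#6 pin down the failure modes: a scalar has no batch dimension to pad, statically mismatched batch dimensions can be rejected while the graph is built (ValueError), and dynamically mismatched ones can only be caught when the session runs (tf.errors.InvalidArgumentError). A hypothetical validation sketch consistent with those three tests, not the library's actual code:

  import tensorflow.compat.v1 as tf

  def check_leading_dims_sketch(flat_tensors):
    """Hypothetical checks mirroring the error cases tested above."""
    for tensor in flat_tensors:
      if tensor.shape.ndims == 0:
        # Example #6: nothing to pad without a batch dimension.
        raise ValueError("Tensor %s is a scalar." % tensor.name)
    first = flat_tensors[0]
    asserts = []
    for tensor in flat_tensors[1:]:
      if (first.shape[:1].is_fully_defined() and
          tensor.shape[:1].is_fully_defined()):
        # Example #5: both batch dims are static; mismatches are
        # detectable at graph-construction time.
        if first.shape[0] != tensor.shape[0]:
          raise ValueError(
              "Batch size of dataset tensors does not match: "
              "%s vs. %s" % (first.shape, tensor.shape))
      else:
        # Example #4: at least one batch dim is dynamic; the check is
        # deferred to run time and surfaces as an InvalidArgumentError.
        asserts.append(tf.Assert(
            tf.equal(tf.shape(first)[0], tf.shape(tensor)[0]),
            ["Batch size of dataset tensors %s and %s do not match."
             % (tensor.name, first.name)]))
    return asserts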