Ejemplo n.º 1
0
    def test_custom_data_provider(self):
        """Check shape, dtype and value range of `provide_custom_data` output.

        The same '*.jpg' pattern under `self.testdata_dir` is passed twice, so
        the provider is asked for two image batches. Each batch is checked
        statically (shape and dtype of the returned tensors) and again after
        being evaluated in a session (shape and absolute values <= 1.0).

        The test is skipped when eager execution is enabled.
        """
        if tf.executing_eagerly():
            # dataset.make_initializable_iterator is not supported when eager
            # execution is enabled. Report an explicit skip rather than
            # returning silently, which would be counted as a passing test.
            self.skipTest(
                'make_initializable_iterator is not supported when eager '
                'execution is enabled.')
        file_pattern = os.path.join(self.testdata_dir, '*.jpg')
        batch_size = 3
        patch_size = 8
        images_list = data_provider.provide_custom_data(
            batch_size=batch_size,
            image_file_patterns=[file_pattern, file_pattern],
            patch_size=patch_size)

        # Static checks on the returned tensors, before anything is run.
        for images in images_list:
            self.assertListEqual([batch_size, patch_size, patch_size, 3],
                                 images.shape.as_list())
            self.assertEqual(tf.float32, images.dtype)

        with self.cached_session() as sess:
            # Initializers are run before the batches are fetched.
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            images_out_list = sess.run(images_list)
            for images_out in images_out_list:
                self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
                                      images_out.shape)
                # Every evaluated value must lie within [-1, 1].
                self.assertTrue(np.all(np.abs(images_out) <= 1.0))
Ejemplo n.º 2
0
def _get_data(image_set_x_file_pattern, image_set_y_file_pattern, batch_size,
              patch_size):
  """Returns image Tensors from a custom provider or TFDS."""
  if image_set_x_file_pattern and image_set_y_file_pattern:
    image_file_patterns = [image_set_x_file_pattern, image_set_y_file_pattern]
  else:
    if image_set_x_file_pattern or image_set_y_file_pattern:
      raise ValueError('Both image patterns or neither must be provided.')
    image_file_patterns = None
  images_x, images_y = data_provider.provide_custom_data(
      batch_size=batch_size,
      image_file_patterns=image_file_patterns,
      patch_size=patch_size)

  return images_x, images_y
Ejemplo n.º 3
0
  def test_custom_data_provider(self):
    """Verifies shape, dtype and value range of `provide_custom_data` batches.

    The same '*.jpg' pattern under `self.testdata_dir` is supplied twice. Each
    returned tensor is checked statically, then evaluated in a session and
    checked again for shape and for absolute values no greater than 1.0.
    """
    batch_size = 3
    patch_size = 8
    expected_shape = [batch_size, patch_size, patch_size, 3]
    jpg_glob = os.path.join(self.testdata_dir, '*.jpg')

    image_tensors = data_provider.provide_custom_data(
        batch_size=batch_size,
        image_file_patterns=[jpg_glob, jpg_glob],
        patch_size=patch_size)

    # Static properties of the tensors, before running anything.
    for tensor in image_tensors:
      self.assertListEqual(expected_shape, tensor.shape.as_list())
      self.assertEqual(tf.float32, tensor.dtype)

    with self.cached_session() as sess:
      # Initializers are run before the batches are fetched.
      sess.run(tf.compat.v1.local_variables_initializer())
      sess.run(tf.compat.v1.tables_initializer())
      evaluated = sess.run(image_tensors)
      for array in evaluated:
        self.assertTupleEqual(tuple(expected_shape), array.shape)
        # Every evaluated value must lie within [-1, 1].
        within_unit_range = np.all(np.abs(array) <= 1.0)
        self.assertTrue(within_unit_range)