Example #1
  def load_mnist_dataset_info(self):
    mnist_info_path = os.path.join(
        utils.tfds_path(),
        "testing/test_data/dataset_info/mnist/3.0.1",
    )
    mnist_info_path = os.path.normpath(mnist_info_path)
    self.read_from_directory(mnist_info_path)
Example #2
    def test_image_custom_decode(self):

        # Do not use random values here: JPEG compression is lossy, so the
        # decoded value would not match the input exactly.
        img_shaped = np.ones(shape=(30, 60, 3), dtype=np.uint8)
        x, y, w, h = 4, 7, 10, 13
        img_cropped = img_shaped[y:y + h, x:x + w, :]

        class DecodeCrop(decode_lib.Decoder):
            """Simple class on how to customize the decoding."""
            def decode_example(self, serialized_image):
                return tf.image.decode_and_crop_jpeg(
                    serialized_image,
                    [y, x, h, w],
                    channels=self.feature.shape[-1],
                )

        @decode_lib.make_decoder()
        def decode_crop(serialized_image, feature):
            return tf.image.decode_and_crop_jpeg(
                serialized_image,
                [y, x, h, w],
                channels=feature.shape[-1],
            )

        image_path = os.fspath(
            utils.tfds_path('testing/test_data/test_image.jpg'))
        with tf.io.gfile.GFile(image_path, 'rb') as f:
            serialized_img = f.read()

        self.assertFeature(
            # Image with statically defined shape
            feature=features_lib.Image(shape=(30, 60, 3),
                                       encoding_format='jpeg'),
            shape=(30, 60, 3),
            dtype=tf.uint8,
            # Output shape is different.
            test_tensor_spec=False,
            tests=[
                testing.FeatureExpectationItem(
                    value=img_shaped,
                    expected=img_cropped,
                    shape=(13, 10, 3),  # Shape is cropped
                    decoders=DecodeCrop(),
                ),
                testing.FeatureExpectationItem(
                    value=img_shaped,
                    expected=img_cropped,
                    shape=(13, 10, 3),  # Shape is cropped
                    decoders=decode_crop(),  # pylint: disable=no-value-for-parameter
                ),
                testing.FeatureExpectationItem(
                    value=image_path,
                    expected=serialized_img,
                    shape=(),
                    dtype=tf.string,
                    decoders=decode_lib.SkipDecoding(),
                ),
            ],
        )
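Both decoders above are equivalent; outside the test harness, a custom decoder is normally passed to `tfds.load` through its `decoders=` argument. A minimal sketch, assuming a dataset with an `'image'` feature (the dataset name and feature key are illustrative, not from this test):

```python
import tensorflow_datasets as tfds

# Hypothetical usage: crop images while they are decoded in the input pipeline.
ds = tfds.load(
    'some_image_dataset',  # placeholder name
    split='train',
    decoders={'image': DecodeCrop()},  # or decode_crop()
)
```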
Example #3
def dummy_register():
  """Dummy register."""

  with tempfile.TemporaryDirectory() as tmp_path:
    tmp_path = pathlib.Path(tmp_path)

    source_path = utils.tfds_path() / 'testing/dummy_dataset/dummy_dataset.py'

    # Single-file dataset package (without checksums)
    src_single = dataset_sources.DatasetSource.from_json(os.fspath(source_path))

    # Multi-file dataset package (with checksums)
    src_multi = dataset_sources.DatasetSource.from_json({
        'root_path': os.fspath(source_path.parent),
        'filenames': ['checksums.tsv', 'dummy_dataset.py'],
    })
    src_multi_json = json.dumps(src_multi.to_json())  # `dict` -> `str`

    # Create the remote index content
    # Note that `src_multi_json` is not wrapped in `"`: it is already valid
    # JSON and parses back into a `dict`.
    content = textwrap.dedent(
        f"""\
        {{"name": "kaggle:dummy_dataset", "source": "{src_single.to_json()}"}}
        {{"name": "kaggle:ds1", "source": "{src_single.to_json()}"}}
        {{"name": "mlds:dummy_dataset", "source": {src_multi_json}}}
        """
    )
    dummy_path = tmp_path / 'dummy-community-datasets.toml'
    dummy_path.write_text(content)

    with mock_cache_path(tmp_path / 'cache'):
      yield register_package.PackageRegister(path=dummy_path)
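Because `dummy_register` yields the register instead of returning it, it is presumably meant to be consumed as a context manager or test fixture so the temporary directory is cleaned up afterwards. A minimal sketch of one way to use it, assuming pytest (the fixture name is illustrative):

```python
import pytest

@pytest.fixture
def register():
  # Delegates to the generator above; the temporary directory and the mocked
  # cache path stay alive for the duration of the test.
  yield from dummy_register()
```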
Example #4
def test_compute_url_info():
  filepath = utils.tfds_path() / 'testing/test_data/6pixels.png'

  expected_url_info = checksums.UrlInfo(
      checksum=
      '04f38ebed34d3b027d2683193766155912fba647158c583c3bdb4597ad8af34c',
      size=utils.Size(102),
      filename='6pixels.png',
  )
  url_info = checksums.compute_url_info(filepath, checksum_cls=hashlib.sha256)
  assert url_info == expected_url_info
  assert url_info.filename == expected_url_info.filename
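For reference, the expected value above is simply the SHA-256 digest and byte size of the fixture file, which is roughly what `compute_url_info` computes. A minimal stand-alone sketch using only the standard library (the relative path assumes the TFDS source tree as working directory):

```python
import hashlib
import os

path = 'testing/test_data/6pixels.png'  # same fixture as in the test above
with open(path, 'rb') as f:
    data = f.read()
print(hashlib.sha256(data).hexdigest())  # 04f38ebe... per the test above
print(os.path.getsize(path))  # 102 bytes per the test above
```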
Example #5
  def _checksums_path(cls) -> ReadOnlyPath:
    """Returns the checksums path."""
    # Used:
    # * To load the checksums (in url_infos)
    # * To save the checksums (in DownloadManager)
    new_path = cls.code_path.parent / "checksums.tsv"
    # Checksums of legacy datasets are located in a separate dir.
    legacy_path = utils.tfds_path() / "url_checksums" / f"{cls.name}.txt"
    if (
        # zipfile.Path does not have `.parts`. Additionally, `os.fspath`
        # will extract the file, so use `str`.
        "tensorflow_datasets" in str(new_path) and legacy_path.exists() and
        not new_path.exists()):
      return legacy_path
    else:
      return new_path
Example #6
def dummy_register():
    """Dummy register."""

    with tempfile.TemporaryDirectory() as tmp_path:
        tmp_path = pathlib.Path(tmp_path)

        # Create the remote index content
        source_path = utils.tfds_path() / 'testing/dummy_dataset'
        source_str = os.fspath(source_path)
        content = textwrap.dedent(f"""\
        {{"name": "kaggle:ds0", "source": "{source_str}"}}
        {{"name": "kaggle:ds1", "source": "{source_str}"}}
        {{"name": "mlds:ds0", "source": "{source_str}"}}
        """)
        dummy_path = tmp_path / 'dummy-community-datasets.toml'
        dummy_path.write_text(content)

        with mock_cache_path(tmp_path / 'cache'):
            yield register_package.PackageRegister(path=dummy_path)
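Each line of the generated index is a self-contained JSON object, so a reader can parse the file line by line. A minimal sketch of how the index written above could be read back (illustrative, not part of the register implementation):

```python
import json

with dummy_path.open() as f:
    for line in f:
        spec = json.loads(line)  # e.g. {"name": "kaggle:ds0", "source": "..."}
        print(spec['name'], spec['source'])
```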
Example #7
def _get_colormap() -> np.ndarray:
  """Loads the colormap.

  The colormap was precomputed using the Glasbey et al. algorithm (Colour
  Displays for Categorical Images, 2017) to generate maximally distinct colors.

  It was generated using https://github.com/taketwo/glasbey:

  ```python
  gb = glasbey.Glasbey(
      base_palette=[(0, 0, 0), (228, 26, 28), (55, 126, 184), (77, 175, 74)],
      no_black=True,
  )
  palette = gb.generate_palette(size=256)
  gb.save_palette(palette, 'colormap.csv')
  ```

  Returns:
    colormap: A `np.array(shape=(255, 3), dtype=np.uint8)` representing the
      mapping id -> color.
  """
  colormap_path = utils.tfds_path() / 'core/features/colormap.csv'
  with colormap_path.open() as f:
    return np.array(list(csv.reader(f)), dtype=np.uint8)
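For context, such a colormap is typically applied by fancy indexing: an integer id map indexes into the `(N, 3)` palette to produce an RGB image. A minimal sketch (the segmentation array is made up for illustration):

```python
import numpy as np

colormap = _get_colormap()                        # (N, 3) uint8 palette
segmentation = np.random.randint(0, 10, (4, 6))   # fake (H, W) label ids
rgb = colormap[segmentation]                      # (H, W, 3) uint8 colors
```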
Example #8
    def test_video_custom_decode(self):

        image_path = os.fspath(
            utils.tfds_path('testing/test_data/test_image.jpg'))
        with tf.io.gfile.GFile(image_path, 'rb') as f:
            serialized_img = f.read()

        self.assertFeature(
            # Video with statically defined frame shape
            feature=features_lib.Video(shape=(None, 30, 60, 3)),
            shape=(None, 30, 60, 3),
            dtype=tf.uint8,
            tests=[
                testing.FeatureExpectationItem(
                    value=[image_path] * 15,  # 15 frames of video
                    expected=[serialized_img] * 15,  # Non-decoded image
                    shape=(15, ),
                    dtype=tf.string,  # Frames stay serialized (decoding skipped)
                    decoders=decode_lib.SkipDecoding(),
                ),
            ],
        )

        # Test with FeatureDict
        self.assertFeature(
            feature=features_lib.FeaturesDict({
                'image':
                features_lib.Image(shape=(30, 60, 3), encoding_format='jpeg'),
                'label':
                tf.int64,
            }),
            shape={
                'image': (30, 60, 3),
                'label': (),
            },
            dtype={
                'image': tf.uint8,
                'label': tf.int64,
            },
            tests=[
                testing.FeatureExpectationItem(
                    decoders={
                        'image': decode_lib.SkipDecoding(),
                    },
                    value={
                        'image': image_path,
                        'label': 123,
                    },
                    expected={
                        'image': serialized_img,
                        'label': 123,
                    },
                    shape={
                        'image': (),
                        'label': (),
                    },
                    dtype={
                        'image': tf.string,
                        'label': tf.int64,
                    },
                ),
            ],
        )
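Skipping the decoding is mostly useful outside of tests, e.g. when the encoded bytes should be decoded later in the pipeline or not at all. A minimal sketch with `tfds.load`, assuming a dataset that has an `'image'` feature (the dataset name is illustrative):

```python
import tensorflow_datasets as tfds

# Hypothetical usage: keep images as encoded JPEG bytes (tf.string) instead of
# decoding them to uint8 tensors.
ds = tfds.load(
    'some_image_dataset',  # placeholder name
    split='train',
    decoders={'image': tfds.decode.SkipDecoding()},
)
```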
Example #9
"""Methods to retrieve and store size/checksums associated to URLs."""

import hashlib
import io
from typing import Any, Dict, Iterable, Optional

from absl import logging
import dataclasses

from tensorflow_datasets.core import utils
from tensorflow_datasets.core.utils import type_utils

_CHECKSUM_DIRS = [
    utils.tfds_path() / 'url_checksums',
]
_CHECKSUM_SUFFIX = '.txt'


@dataclasses.dataclass(eq=True)
class UrlInfo:
    """Small wrapper around the url metadata (checksum, size).

  Attributes:
    size: Download size of the file
    checksum: Checksum of the file
    filename: Name of the file
  """
    size: utils.Size
    checksum: str
Example #10
# coding=utf-8
"""Tests for tensorflow_datasets.core.deprecated.text.subword_text_encoder."""
from __future__ import unicode_literals

import os

from absl.testing import parameterized
import tensorflow.compat.v2 as tf
from tensorflow_datasets import testing
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.deprecated.text import subword_text_encoder
from tensorflow_datasets.core.deprecated.text import text_encoder

TEST_DATA_DIR = os.path.join(utils.tfds_path(), 'testing', 'test_data')


class SubwordTextEncoderTest(parameterized.TestCase, testing.TestCase):
    def setUp(self):
        super(SubwordTextEncoderTest, self).setUp()
        # Vocab ids will be (offset for pad=0):
        #                  1       2       3      4      5
        self.vocab_list = ['foo_', 'bar_', 'foo', 'bar', '<EOS>']
        self.encoder = subword_text_encoder.SubwordTextEncoder(
            vocab_list=self.vocab_list)

    def test_vocab_size(self):
        # Bytes + pad + subwords
        self.assertEqual((256 + 1 + len(self.vocab_list)),
                         self.encoder.vocab_size)
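The vocabulary size combines 256 byte fallbacks, the pad id, and the subwords, which is why arbitrary text stays encodable even without matching subwords. A minimal sketch of a round trip through such an encoder (illustrative, not from the test file; it assumes the encoder is invertible, as the TFDS text encoders are designed to be):

```python
# Hypothetical usage: byte fallback keeps encode/decode lossless even for
# tokens outside the subword vocabulary.
encoder = subword_text_encoder.SubwordTextEncoder(
    vocab_list=['foo_', 'bar_', 'foo', 'bar', '<EOS>'])
ids = encoder.encode('foo bar baz')
assert encoder.decode(ids) == 'foo bar baz'
```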
Example #11
import os

import tensorflow.compat.v2 as tf
from tensorflow_datasets import testing
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import download
from tensorflow_datasets.core import features
from tensorflow_datasets.core import read_only_builder
from tensorflow_datasets.core import utils
from tensorflow_datasets.image_classification import mnist

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

tf.enable_v2_behavior()

_TFDS_DIR = utils.tfds_path()
_INFO_DIR = os.path.join(_TFDS_DIR, "testing", "test_data", "dataset_info",
                         "mnist", "3.0.1")
_INFO_DIR_UNLABELED = os.path.join(_TFDS_DIR, "testing", "test_data",
                                   "dataset_info", "mnist_unlabeled", "3.0.1")
_NON_EXISTENT_DIR = os.path.join(_TFDS_DIR, "non_existent_dir")

DummyDatasetSharedGenerator = testing.DummyDatasetSharedGenerator


class RandomShapedImageGenerator(DummyDatasetSharedGenerator):
    def _info(self):
        return dataset_info.DatasetInfo(
            builder=self,
            features=features.FeaturesDict({"im": features.Image()}),
            supervised_keys=("im", "im"),
Example #12
    # exceptions are raised.
    if len(registers) == 1:
      return registers[0].builder(name, **builder_kwargs)

    if len(registers) > 1:
      raise ValueError(f'Namespace {name.namespace} has multiple registers! '
                       f'This should not happen! Registers: {registers}')

    raise registered.DatasetNotFoundError(
        f'Namespace {name.namespace} found with {len(registers)} registers, '
        f'but could not load dataset {name.name}.')

  def get_builder_root_dirs(self, name: naming.DatasetName) -> List[epath.Path]:
    """Returns root dir of the generated builder (without version/config)."""
    result = []
    registers = self.registers_per_namespace[name.namespace]
    for register in registers:
      if isinstance(register, register_path.DataDirRegister):
        result.extend(register.get_builder_root_dirs(name))
      else:
        raise RuntimeError(f'Not supported for non datadir registers ({name})!')
    return result


def registry_for_config(config_path: epath.PathLike) -> DatasetRegistry:
  return DatasetRegistry(NamespaceConfig(config_path=epath.Path(config_path)))


community_register = registry_for_config(
    config_path=(utils.tfds_path() / 'community-datasets.toml'))