Ejemplo n.º 1
0
def test_registered_logger_is_called(
        dummy_builder: dataset_builder.DatasetBuilder,  # pylint: disable=redefined-outer-name
):
    """Checks that a registered TFDS logger is notified of `as_dataset` calls.

    A `MagicMock` is registered through `tfds_logging.register`, then
    `as_dataset` is invoked once on the dummy builder. The test verifies that
    the returned dataset is unaffected and that the mock logger's
    `as_dataset` method received exactly one call with the expected
    arguments.
    """
    spy_logger = mock.MagicMock()
    tfds_logging.register(spy_logger)
    # NOTE(review): the logger is registered and not removed in this test;
    # confirm whether this registration can leak into later tests.

    config = read_config_lib.ReadConfig(add_tfds_id=True)
    config.try_autocache = False
    # Non-default value, presumably so the `ReadConfig` seen by the logger is
    # clearly the one passed in below -- TODO confirm intent.
    config.num_parallel_calls_for_decode = 42

    dataset = dummy_builder.as_dataset(
        split='train',
        read_config=config,
        as_supervised=True,
    )

    # Registering a logger must not alter what `as_dataset` returns.
    int64_scalar = tf.TensorSpec(shape=(), dtype=tf.int64)
    assert dataset.element_spec == (int64_scalar, int64_scalar)

    # The logger saw exactly one `as_dataset` call, carrying these arguments.
    expected_call = mock.call(
        dataset_name='dummy_dataset',
        config_name='',
        version='1.0.0',
        data_path=mock.ANY,
        split='train',
        shuffle_files=False,
        as_supervised=True,
        batch_size=None,
        decoders=None,
        read_config=config,
    )
    assert spy_logger.as_dataset.call_args_list == [expected_call]
Ejemplo n.º 2
0
def test_add_tfds_id(dummy_builder: dataset_builder.DatasetBuilder):  # pylint: disable=redefined-outer-name
    """Tests `add_tfds_id=True`.

    For both the full `train` split and the `train[1:]` sub-split, checks the
    element spec (an int64 `id` plus a string `tfds_id`) and the exact
    examples yielded.
    """
    config = read_config_lib.ReadConfig(add_tfds_id=True)

    def check_split(split, expected_examples):
        # Loads `split` with the shared config, then checks spec and contents.
        dataset = dummy_builder.as_dataset(split=split, read_config=config)
        assert dataset.element_spec == {
            'id': tf.TensorSpec(shape=(), dtype=tf.int64),
            'tfds_id': tf.TensorSpec(shape=(), dtype=tf.string),
        }
        assert list(dataset_utils.as_numpy(dataset)) == expected_examples

    train_examples = [
        {'id': 0, 'tfds_id': b'dummy_dataset-train.tfrecord-00000-of-00001__0'},
        {'id': 1, 'tfds_id': b'dummy_dataset-train.tfrecord-00000-of-00001__1'},
        {'id': 2, 'tfds_id': b'dummy_dataset-train.tfrecord-00000-of-00001__2'},
    ]

    check_split('train', train_examples)

    # Subsplit API works too: `train[1:]` keeps the original ids.
    check_split('train[1:]', train_examples[1:])
Ejemplo n.º 3
0
def test_add_tfds_id_as_supervised(
    dummy_builder: dataset_builder.DatasetBuilder,  # pylint: disable=redefined-outer-name
):
  """Tests `add_tfds_id=True` with `as_supervised=True`.

  The supervised dataset keeps its 2-tuple of scalar int64 specs; no
  `tfds_id` entry is added even though the read config requests one.
  """
  config = read_config_lib.ReadConfig(add_tfds_id=True)
  dataset = dummy_builder.as_dataset(
      split='train',
      read_config=config,
      as_supervised=True,
  )
  # `add_tfds_id=True` is ignored when `as_supervised=True`.
  int64_scalar = tf.TensorSpec(shape=(), dtype=tf.int64)
  assert dataset.element_spec == (int64_scalar, int64_scalar)