示例#1
0
文件: voc.py 项目: sbl1996/hanser
def make_voc_dataset_sub(n_train,
                         n_val,
                         batch_size,
                         eval_batch_size,
                         transform,
                         data_dir=None,
                         prefetch=True):
    """Build small PASCAL VOC 2012 train/eval pipelines from TFDS.

    Args:
        n_train: number of leading examples of the ``train`` split to train on.
        n_val: number of leading examples of the ``train`` split to evaluate on.
        batch_size: training batch size.
        eval_batch_size: evaluation batch size.
        transform: factory called as ``transform(training=...)``; its result is
            passed to ``prepare`` as the preprocessing function.
        data_dir: optional TFDS data directory.
        prefetch: forwarded to ``prepare`` for both pipelines.

    Returns:
        ``(ds_train, ds_val, steps_per_epoch, val_steps)``.
    """
    steps_per_epoch = n_train // batch_size
    # The eval pipeline keeps its final partial batch (drop_remainder=False),
    # so round up; flooring would skip the trailing examples whenever callers
    # run exactly `val_steps` eval batches.
    val_steps = -(-n_val // eval_batch_size)

    ds_train = tfds.load("voc/2012",
                         split=f"train[:{n_train}]",
                         data_dir=data_dir,
                         shuffle_files=True,
                         read_config=tfds.ReadConfig(try_autocache=False,
                                                     skip_prefetch=True))
    # NOTE(review): the eval examples also come from the start of the `train`
    # split, so they overlap the training subset. Presumably intentional (e.g.
    # an overfitting sanity check) — confirm before treating it as validation.
    ds_val = tfds.load("voc/2012",
                       split=f"train[:{n_val}]",
                       data_dir=data_dir,
                       shuffle_files=False,
                       read_config=tfds.ReadConfig(try_autocache=False,
                                                   skip_prefetch=True))
    ds_train = prepare(ds_train,
                       batch_size,
                       transform(training=True),
                       training=True,
                       repeat=False,
                       prefetch=prefetch)
    ds_val = prepare(ds_val,
                     eval_batch_size,
                     transform(training=False),
                     training=False,
                     repeat=False,
                     drop_remainder=False,
                     prefetch=prefetch)
    return ds_train, ds_val, steps_per_epoch, val_steps
示例#2
0
def _get_tfds_dataset(
        dataset: str,
        rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
    """Load this host's slice of a TFDS dataset's train and eval splits.

    Args:
        dataset: TFDS dataset name, e.g. ``"cifar10"``.
        rng: JAX PRNG key; a per-host key derived from it seeds the training
            file shuffle.

    Returns:
        ``(train_ds, eval_ds, num_classes)``. ``num_classes`` is 0 when the
        dataset has no ``"label"`` feature. The eval split is ``"test"`` except
        for ``imagenet2012``, which uses ``"validation"``.
    """
    dataset_builder = tfds.builder(dataset)
    num_classes = 0
    if "label" in dataset_builder.info.features:
        num_classes = dataset_builder.info.features["label"].num_classes

    # Make sure each host uses a different RNG for the training data.
    _, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    data_rng, shuffle_rng = jax.random.split(data_rng)
    train_split = deterministic_data.get_read_instruction_for_host(
        "train", dataset_builder.info.splits["train"].num_examples)
    train_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
    train_ds = dataset_builder.as_dataset(split=train_split,
                                          shuffle_files=True,
                                          read_config=train_read_config)

    eval_split_name = {
        "cifar10": "test",
        "imagenet2012": "validation"
    }.get(dataset, "test")

    eval_split_size = dataset_builder.info.splits[eval_split_name].num_examples
    eval_split = deterministic_data.get_read_instruction_for_host(
        eval_split_name, eval_split_size)
    # The seed has no effect here because eval files are not shuffled.
    eval_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[1])
    eval_ds = dataset_builder.as_dataset(split=eval_split,
                                         shuffle_files=False,
                                         read_config=eval_read_config)
    return train_ds, eval_ds, num_classes
示例#3
0
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Read ``tfds_split`` from ``tfds_builder`` as a ``tf.data.Dataset``.

    Args:
        tfds_builder: builder to read from (prepared on demand).
        tfds_split: split to read.
        tfds_skip_decoding_feature: comma-separated feature names to leave
            undecoded; empty for none.
        tfds_as_supervised: forwarded to ``as_dataset(as_supervised=...)``.
        input_context: optional distribution input context used for sharding.
        seed: file-shuffle seed.
        is_training: shuffles files and (unless ``cache``) repeats forever.
        cache: when true, the dataset is not repeated here.
        cycle_length: interleave cycle length.
        block_length: interleave block length.

    Returns:
        The (possibly sharded and repeated) dataset.
    """
    # Downloads and prepares the data; a no-op when it already exists.
    tfds_builder.download_and_prepare()

    skipped_features = (tfds_skip_decoding_feature.split(',')
                        if tfds_skip_decoding_feature else [])
    decoders = {
        feature.strip(): tfds.decode.SkipDecoding()
        for feature in skipped_features
    }

    # The tfds mock path often does not provide splits; count it as one shard.
    num_shards = (len(tfds_builder.info.splits[tfds_split].file_instructions)
                  if tfds_builder.info.splits else 1)

    # With fewer files than input pipelines, file-level sharding cannot feed
    # every pipeline, so read the whole split and shard in host memory instead.
    shard_in_memory = bool(
        input_context and num_shards < input_context.num_input_pipelines)

    read_config = tfds.ReadConfig(
        interleave_cycle_length=cycle_length,
        interleave_block_length=block_length,
        input_context=None if shard_in_memory else input_context,
        shuffle_seed=seed)
    dataset = tfds_builder.as_dataset(split=tfds_split,
                                      shuffle_files=is_training,
                                      as_supervised=tfds_as_supervised,
                                      decoders=decoders,
                                      read_config=read_config)
    if shard_in_memory:
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)

    # Presumably the caller repeats after cache() when caching is enabled.
    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
示例#4
0
 def __init__(self):
     """Load the IMDB reviews ``train`` and ``test`` splits from TFDS.

     The splits are stored on ``self.train`` and ``self.test``; previously
     they only lived in locals and were lost when ``__init__`` returned.
     """
     import tensorflow_datasets as tfds
     # NOTE(review): `seed` is not defined in this method; it must come from
     # an enclosing or module scope — confirm it exists there.
     train = tfds.load(
         'imdb_reviews',
         split='train',
         read_config=tfds.ReadConfig(shuffle_seed=seed,
                                     shuffle_reshuffle_each_iteration=True),
     )
     test = tfds.load(
         'imdb_reviews',
         split='test',
         read_config=tfds.ReadConfig(shuffle_seed=seed,
                                     shuffle_reshuffle_each_iteration=True),
     )
     print(train)
     self.train = train
     self.test = test
示例#5
0
def _load_dataset(split, should_shuffle, data_rng, data_dir):
    """Load one split of ``ogbg_molpcba`` from TFDS.

    Args:
        split: split name; ``'eval_train'`` reads the ``'train'`` split.
        should_shuffle: when true, shuffle files and examples with seeds taken
            from ``data_rng`` and repeat the dataset indefinitely.
        data_rng: JAX PRNG key; only used when ``should_shuffle`` is true.
        data_dir: TFDS data directory.

    Returns:
        A ``tf.data.Dataset`` whose examples include a ``tfds_id`` feature.
    """
    file_seed = dataset_seed = None
    if should_shuffle:
        file_key, dataset_key = jax.random.split(data_rng)
        file_seed, dataset_seed = file_key[0], dataset_key[0]

    tfds_split = 'train' if split == 'eval_train' else split
    dataset = tfds.load(
        'ogbg_molpcba',
        split=tfds_split,
        shuffle_files=should_shuffle,
        read_config=tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_seed),
        data_dir=data_dir)

    if not should_shuffle:
        # Eval datasets are not repeated here: the caller wraps the eval
        # iterator in itertools.cycle, which keeps the elements in memory.
        return dataset
    dataset = dataset.shuffle(seed=dataset_seed, buffer_size=2**15)
    return dataset.repeat()
示例#6
0
def _merge_langs_dataset_fn(split, shuffle_files, tfds_name, langs, seed):
  """Creates a single dataset containing all languages in a tfds dataset.

  Loads individual language datasets for the specified languages and
  concatenates them into a single dataset. This is done so that we can have a
  single test task which contains all the languages in order to compute a single
  overall performance metric across all languages.

  Args:
    split: string, which split to use.
    shuffle_files: boolean, whether to shuffle the input files.
    tfds_name: string, name of the TFDS dataset to load.
    langs: list of strings, the languages to load.
    seed: file-shuffle seed passed to `tfds.ReadConfig(shuffle_seed=...)` for
      every language.

  Returns:
    A tf.data.Dataset with the languages concatenated in the order of `langs`.

  Raises:
    ValueError: if `langs` is empty.
  """
  datasets = [
      tfds.load("{}/{}".format(tfds_name, lang),
                split=split,
                shuffle_files=shuffle_files,
                read_config=tfds.ReadConfig(shuffle_seed=seed))
      for lang in langs
  ]
  if not datasets:
    raise ValueError("`langs` must contain at least one language.")

  all_langs_data = datasets[0]
  for lang_data in datasets[1:]:
    all_langs_data = all_langs_data.concatenate(lang_data)
  return all_langs_data
示例#7
0
  def test_disable_ordering_guard(self):
    """Shuffling a sliced dataset should still yield the same examples.

    Reads one class of the `fungi` `valid` meta-split through TFDS with
    `shuffle_files=True` and the ordering guard disabled, then compares its
    images with the TFRecord file `{source}/{label + offset}.tfrecords`
    (presumably the same class under its absolute id).
    """
    source, split = 'fungi', 'valid'
    label, offset = 13, 994
    shuffle_seed = 1234

    builder = tfds.builder(
        'meta_dataset', config=source, data_dir=FLAGS.tfds_path)
    read_config = tfds.ReadConfig(
        shuffle_seed=shuffle_seed, enable_ordering_guard=False)
    class_dataset = builder.as_class_dataset(
        md_version='v1',
        meta_split=split,
        relative_label=label,
        decoders={'image': tfds.decode.SkipDecoding()},
        batch_size=-1,
        shuffle_files=True,
        read_config=read_config)
    tfds_images = tuple(tfds.as_numpy(class_dataset)['image'])

    record_path = (
        f'{FLAGS.meta_dataset_path}/{source}/{label + offset}.tfrecords')
    record_dataset = (
        tf.data.TFRecordDataset(record_path)
        .map(test_utils.parse_example)
        .batch(10_000, drop_remainder=False))
    (record_batch,) = tfds.as_numpy(record_dataset)
    md_images = tuple(record_batch['image'])

    # Shuffled files change the order (`shuffle_files = True`)...
    self.assertNotEqual(tfds_images, md_images)
    # ...but not which examples are present.
    self.assertSameElements(tfds_images, md_images)
示例#8
0
    def load_tfds(self):
        """Build a shuffled, supervised dataset for ``self._split`` via TFDS.

        Calls ``set_hard_limit_num_open_files()`` first (presumably raising the
        open-file limit for the interleaved reads), prepares the dataset if
        needed, and leaves ``'image'`` undecoded when ``self._skip_decoding``
        is set. The builder is kept on ``self._builder``.
        """
        logger.info('Using TFDS to load data.')
        set_hard_limit_num_open_files()

        builder = tfds.builder(self._dataset_name, data_dir=self._dataset_dir)
        self._builder = builder
        builder.download_and_prepare()

        decoders = ({'image': tfds.decode.SkipDecoding()}
                    if self._skip_decoding else {})
        # Interleave up to 64 files, taking one example from each in turn.
        read_config = tfds.ReadConfig(interleave_cycle_length=64,
                                      interleave_block_length=1)
        return builder.as_dataset(split=self._split,
                                  as_supervised=True,
                                  shuffle_files=True,
                                  decoders=decoders,
                                  read_config=read_config)
示例#9
0
 def __init__(self,
              version: Literal[10, 20, 100],
              quantize_bits: int = 8,
              seed: int = 1):
     """Load CIFAR-10, CIFAR-20 (CIFAR-100 coarse labels) or CIFAR-100.

     ``train`` holds the first 48000 training images, ``valid`` the rest of the
     training split and ``test`` the test split; every element is mapped to an
     ``(image, label)`` tuple.

     Args:
       version: 10, 20 or 100.
       quantize_bits: not used by this method.
       seed: shuffle seed for TFDS file shuffling.
     """
     assert version in (10, 20, 100), \
       "Only support CIFAR-10, CIFAR-20 and CIFAR-100"
     self.version = version
     dsname = 'cifar10' if version == 10 else 'cifar100'
     self.train, self.valid, self.test = tfds.load(
         name=dsname,
         split=['train[:48000]', 'train[48000:]', 'test'],
         # as_supervised=True,
         read_config=tfds.ReadConfig(shuffle_seed=seed,
                                     shuffle_reshuffle_each_iteration=True),
         shuffle_files=True,
         with_info=False,
     )
     # CIFAR-20 uses the 20 coarse CIFAR-100 superclass labels.
     if version == 20:
         process = lambda dat: (dat['image'], dat['coarse_label'])
     elif version in (10, 100):
         process = lambda dat: (dat['image'], dat['label'])
     self.train, self.valid, self.test = (
         split_ds.map(process)
         for split_ds in (self.train, self.valid, self.test))
示例#10
0
def _load_dataset(split,
                  should_shuffle=False,
                  shuffle_seed=None,
                  shuffle_buffer_size=None):
    """Load a split of ``ogbg_molpcba`` from TFDS.

    Args:
        split: TFDS split name.
        should_shuffle: when true, shuffle files and examples and repeat the
            dataset indefinitely; requires both seed and buffer size.
        shuffle_seed: JAX PRNG key split into file and example shuffle seeds.
        shuffle_buffer_size: example shuffle buffer size.

    Returns:
        A ``tf.data.Dataset`` whose examples include a ``tfds_id`` feature.
    """
    file_shuffle_seed = dataset_shuffle_seed = None
    if should_shuffle:
        assert shuffle_seed is not None and shuffle_buffer_size is not None
        file_key, dataset_key = jax.random.split(shuffle_seed)
        file_shuffle_seed, dataset_shuffle_seed = file_key[0], dataset_key[0]

    dataset = tfds.load(
        'ogbg_molpcba',
        split=split,
        shuffle_files=should_shuffle,
        read_config=tfds.ReadConfig(add_tfds_id=True,
                                    shuffle_seed=file_shuffle_seed))
    if not should_shuffle:
        return dataset
    shuffled = dataset.shuffle(seed=dataset_shuffle_seed,
                               buffer_size=shuffle_buffer_size)
    return shuffled.repeat()
示例#11
0
def large_imagenet(data_dir=None):
    """Return ImageNet-2012 as ``tf.data.Dataset`` pipelines.

    JPEG decoding is skipped, so images arrive as encoded byte strings; this
    keeps the input pipeline fast when later transformations (such as random
    crops) decode the images themselves.

    Dataset homepage: http://www.image-net.org/

    Args:
        data_dir: Directory TFDS reads data from and writes data to.

    Returns:
        A dict with ``"train_ds"``, ``"val_ds"`` and ``"info"`` (the TFDS
        dataset info), plus ``"minival_ds"`` when the info lists a ``minival``
        split. Datasets yield ``(image, label)`` pairs with files shuffled.
    """
    splits, info = tfds.load('imagenet2012',
                             shuffle_files=True,
                             read_config=tfds.ReadConfig(skip_prefetch=True),
                             decoders={'image': tfds.decode.SkipDecoding()},
                             as_supervised=True,
                             with_info=True,
                             data_dir=data_dir)
    result = {
        "train_ds": splits['train'],
        "val_ds": splits['validation'],
        "info": info,
    }
    if 'minival' in info.splits:
        result['minival_ds'] = splits['minival']
    return result
示例#12
0
def init_data():
    """Build the MNIST train/test input pipelines.

    Reads ``seed``, ``batch_size``, ``test_batch_size``, ``nepochs`` and
    ``save_freq`` from the global ``parse_args``.

    Returns:
        ``(ds_train, ds_test_eval, meta)``: both datasets are endlessly
        repeating numpy iterables of ``(image, label)`` batches, and ``meta``
        holds ``num_batches`` (train batches per epoch) and
        ``num_test_batches``.
    """
    # One seed for both file shuffling and the example shuffle buffer, so the
    # whole pipeline is reproducible from the command-line seed.
    seed = parse_args.seed
    (ds_train, ds_test), ds_info = tfds.load('mnist',
                                             split=['train', 'test'],
                                             shuffle_files=True,
                                             with_info=True,
                                             as_supervised=True,
                                             read_config=tfds.ReadConfig(shuffle_seed=seed,
                                                                         try_autocache=False))

    num_train = ds_info.splits['train'].num_examples
    num_test = ds_info.splits['test'].num_examples

    # make sure all batches are the same size to minimize jit compilation cache
    assert num_train % parse_args.batch_size == 0
    num_batches = num_train // parse_args.batch_size
    assert num_test % parse_args.test_batch_size == 0
    num_test_batches = num_test // parse_args.test_batch_size

    # make sure we always save the model on the last iteration
    assert num_batches * parse_args.nepochs % parse_args.save_freq == 0

    ds_train = ds_train.cache().repeat().shuffle(1000, seed=seed).batch(parse_args.batch_size)
    ds_test_eval = ds_test.batch(parse_args.test_batch_size).repeat()

    ds_train, ds_test_eval = tfds.as_numpy(ds_train), tfds.as_numpy(ds_test_eval)

    meta = {
        "num_batches": num_batches,
        "num_test_batches": num_test_batches
    }

    return ds_train, ds_test_eval, meta
示例#13
0
    def load_tfds(self) -> tf.data.Dataset:
        """Return a shuffled, supervised dataset for ``self.config.split``.

        Downloads/prepares the data only when ``self.config.download`` is set,
        leaves ``'image'`` undecoded when ``self.config.skip_decoding`` is set,
        and passes ``self.input_context`` to TFDS for input sharding.
        """
        logging.info('Using TFDS to load data.')
        cfg = self.config
        builder = tfds.builder(cfg.name, data_dir=cfg.data_dir)
        if cfg.download:
            builder.download_and_prepare()

        decoders = ({'image': tfds.decode.SkipDecoding()}
                    if cfg.skip_decoding else {})
        # Interleave up to 10 files, taking one example from each in turn.
        read_config = tfds.ReadConfig(interleave_cycle_length=10,
                                      interleave_block_length=1,
                                      input_context=self.input_context)
        return builder.as_dataset(split=cfg.split,
                                  as_supervised=True,
                                  shuffle_files=True,
                                  decoders=decoders,
                                  read_config=read_config)
示例#14
0
 def load_shard(self, file_instruction):
   """Read a single TFRecord shard of the builder's dataset, unshuffled.

   Goes through the builder's private ``_tfrecords_reader`` with a default
   ``tfds.ReadConfig``.
   """
   reader = self.builder._tfrecords_reader  # pylint:disable=protected-access
   return reader.read_files(
       [file_instruction],
       read_config=tfds.ReadConfig(),
       shuffle_files=False)
示例#15
0
def make_dataset(batch_size,
                 eval_batch_size,
                 transform,
                 data_dir=None,
                 drop_remainder=None):
    """Build COCO 2017 train/validation pipelines.

    Args:
        batch_size: training batch size.
        eval_batch_size: validation batch size.
        transform: factory called as ``transform(training=...)``; its result is
            passed to ``prepare`` as the preprocessing function.
        data_dir: optional TFDS data directory.
        drop_remainder: forwarded to ``prepare`` for the validation pipeline;
            when truthy ``val_steps`` rounds down, otherwise it rounds up so
            the final partial batch is counted.

    Returns:
        ``(ds_train, ds_val, steps_per_epoch, val_steps)``.
    """
    n_train, n_val = NUM_EXAMPLES['train'], NUM_EXAMPLES['validation']
    steps_per_epoch = n_train // batch_size
    if drop_remainder:
        val_steps = n_val // eval_batch_size
    else:
        # Exact integer ceiling division (no float rounding).
        val_steps = -(-n_val // eval_batch_size)

    read_config = tfds.ReadConfig(try_autocache=False, skip_prefetch=True)
    ds_train = tfds.load("coco/2017",
                         split="train",
                         data_dir=data_dir,
                         shuffle_files=True,
                         read_config=read_config)
    ds_val = tfds.load("coco/2017",
                       split="validation",
                       data_dir=data_dir,
                       shuffle_files=False,
                       read_config=read_config)
    ds_train = prepare(ds_train,
                       batch_size,
                       transform(training=True),
                       training=True,
                       repeat=False)
    ds_val = prepare(ds_val,
                     eval_batch_size,
                     transform(training=False),
                     training=False,
                     repeat=False,
                     drop_remainder=drop_remainder)
    return ds_train, ds_val, steps_per_epoch, val_steps
示例#16
0
 def _load_tfds(self, *, split, shuffle_seed):
   """Load CIFAR-10's ``train`` (split='train') or ``test`` (split='eval') split.

   Files are shuffled, with ``shuffle_seed``, only when a seed is given.
   """
   tfds_split = {'train': 'train', 'eval': 'test'}[split]
   if shuffle_seed is None:
     return tfds.load('cifar10', split=tfds_split, shuffle_files=False,
                      read_config=None)
   return tfds.load('cifar10', split=tfds_split, shuffle_files=True,
                    read_config=tfds.ReadConfig(shuffle_seed=shuffle_seed))
示例#17
0
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Read ``tfds_split`` from ``tfds_builder`` as a ``tf.data.Dataset``.

    ``tfds_skip_decoding_feature`` is a comma-separated list of features left
    undecoded. Files are shuffled (seeded by ``seed``) when ``is_training``; the
    dataset is repeated indefinitely when ``is_training`` and not ``cache``.
    """
    # Downloads and prepares the data; a no-op when it already exists.
    tfds_builder.download_and_prepare()

    skipped_features = (tfds_skip_decoding_feature.split(',')
                        if tfds_skip_decoding_feature else [])
    decoders = {
        feature.strip(): tfds.decode.SkipDecoding()
        for feature in skipped_features
    }
    dataset = tfds_builder.as_dataset(
        split=tfds_split,
        shuffle_files=is_training,
        as_supervised=tfds_as_supervised,
        decoders=decoders,
        read_config=tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                    interleave_block_length=block_length,
                                    input_context=input_context,
                                    shuffle_seed=seed))

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
示例#18
0
    def _read_tfds(
        self,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Read ``self._tfds_split`` from the configured TFDS builder.

        Uses the instance's interleave settings and seed, skips decoding for
        the comma-separated ``self._tfds_skip_decoding_feature`` names, and
        shuffles files when training.
        """
        builder = self._tfds_builder
        # Downloads and prepares the data; a no-op when it already exists.
        builder.download_and_prepare()

        skipped_features = (self._tfds_skip_decoding_feature.split(',')
                            if self._tfds_skip_decoding_feature else [])
        decoders = {
            feature.strip(): tfds.decode.SkipDecoding()
            for feature in skipped_features
        }
        dataset = builder.as_dataset(
            split=self._tfds_split,
            shuffle_files=self._is_training,
            as_supervised=self._tfds_as_supervised,
            decoders=decoders,
            read_config=tfds.ReadConfig(
                interleave_cycle_length=self._cycle_length,
                interleave_block_length=self._block_length,
                input_context=input_context,
                shuffle_seed=self._seed))

        # With caching enabled, `repeat()` is applied later, after `cache()`.
        if self._is_training and not self._cache:
            dataset = dataset.repeat()
        return dataset
示例#19
0
  def __init__(self, image_size: Optional[int] = 28, seed: int = 1):
    """Load Omniglot, optionally resizing images to ``image_size`` squares.

    ``train`` holds the first 90% of the TFDS train split, ``valid`` the last
    10% and ``test`` the test split; elements are ``(image, label)`` pairs.

    Args:
      image_size: target side length; ``None`` or 105 leaves images unresized.
      seed: shuffle seed passed to ``tfds.ReadConfig``.
    """
    train, valid, test = tfds.load(
        name='omniglot',
        split=['train[:90%]', 'train[90%:]', 'test'],
        read_config=tfds.ReadConfig(shuffle_seed=seed,
                                    shuffle_reshuffle_each_iteration=True),
        as_supervised=True,
    )

    size = 105 if image_size is None else int(image_size)
    if size != 105:
      # NOTE(review): only this path casts labels to float32 (and resize
      # yields float images); the unresized path keeps the TFDS dtypes —
      # confirm callers expect the difference.
      @tf.function
      def resize(x, y):
        x = tf.image.resize(x,
                            size=(size, size),
                            method=tf.image.ResizeMethod.BILINEAR,
                            preserve_aspect_ratio=True,
                            antialias=True)
        return x, tf.cast(y, dtype=tf.float32)

      train, valid, test = (
          split_ds.map(resize, tf.data.AUTOTUNE)
          for split_ds in (train, valid, test))

    self.train = train
    self.valid = valid
    self.test = test
    self._image_size = size
    def _read_tfds(
        self,
        input_context: Optional[tf.distribute.InputContext] = None
    ) -> tf.data.Dataset:
        """Read ``self._tfds_split`` from the configured TFDS builder.

        Downloads/prepares first only when ``self._tfds_download`` is set. No
        shuffle seed is passed to ``tfds.ReadConfig``. Training data has its
        files shuffled and is repeated indefinitely.
        """
        builder = self._tfds_builder
        if self._tfds_download:
            builder.download_and_prepare()

        skipped_features = (self._tfds_skip_decoding_feature.split(',')
                            if self._tfds_skip_decoding_feature else [])
        decoders = {
            feature.strip(): tfds.decode.SkipDecoding()
            for feature in skipped_features
        }
        dataset = builder.as_dataset(
            split=self._tfds_split,
            shuffle_files=self._is_training,
            as_supervised=self._tfds_as_supervised,
            decoders=decoders,
            read_config=tfds.ReadConfig(
                interleave_cycle_length=self._cycle_length,
                interleave_block_length=self._block_length,
                input_context=input_context))

        return dataset.repeat() if self._is_training else dataset
示例#21
0
 def __init__(self):
     """Load the ``math_dataset/arithmetic__mul`` train and test splits.

     The supervised ``(question, answer)`` splits are stored on
     ``self.train_examples`` and ``self.val_examples`` (the latter is the
     ``test`` split); previously they were local variables and were lost
     when ``__init__`` returned.
     """
     import tensorflow_datasets as tfds
     # NOTE(review): `seed` is not defined in this method; it must come from
     # an enclosing or module scope — confirm it exists there.
     train_examples, val_examples = tfds.load(
         'math_dataset/arithmetic__mul',
         split=['train', 'test'],
         read_config=tfds.ReadConfig(shuffle_seed=seed,
                                     shuffle_reshuffle_each_iteration=True),
         as_supervised=True)
     self.train_examples = train_examples
     self.val_examples = val_examples
示例#22
0
文件: helper.py 项目: sbl1996/hanser
def load(name: str, *, split, data_dir=None, shuffle_files: bool = False):
    """Load an already-prepared TFDS dataset without downloading it.

    TFDS auto-caching and its built-in prefetch are both turned off
    (presumably so the caller's own pipeline controls caching/prefetching).
    """
    no_cache_no_prefetch = tfds.ReadConfig(try_autocache=False,
                                           skip_prefetch=True)
    return tfds.load(name,
                     split=split,
                     data_dir=data_dir,
                     download=False,
                     shuffle_files=shuffle_files,
                     read_config=no_cache_no_prefetch)
示例#23
0
 def _load_tfds(self, *, split, shuffle_seed):
   """Load ImageNet-2012 ``train`` (split='train') or ``validation`` (split='eval').

   Images are left JPEG-encoded. Files are shuffled, with ``shuffle_seed``,
   only when a seed is given.
   """
   tfds_split = {'train': 'train', 'eval': 'validation'}[split]
   should_shuffle = shuffle_seed is not None
   read_config = (tfds.ReadConfig(shuffle_seed=shuffle_seed)
                  if should_shuffle else None)
   return tfds.load(
       'imagenet2012',
       split=tfds_split,
       shuffle_files=should_shuffle,
       read_config=read_config,
       decoders={'image': tfds.decode.SkipDecoding()})
示例#24
0
def load_dataset(split: str, batch_size: int) -> Iterator[Batch]:
  """Return an endless iterator of shuffled binarized-MNIST numpy batches.

  Both the file order and the example shuffle buffer are seeded with
  ``FLAGS.random_seed``.
  """
  ds = tfds.load("binarized_mnist", split=split, shuffle_files=True,
                 read_config=tfds.ReadConfig(shuffle_seed=FLAGS.random_seed))
  ds = ds.shuffle(buffer_size=10 * batch_size, seed=FLAGS.random_seed)
  ds = ds.batch(batch_size)
  # Repeat before prefetching: prefetch placed before repeat restarts with an
  # empty buffer at every epoch boundary. The batch sequence is unchanged.
  ds = ds.repeat()
  ds = ds.prefetch(buffer_size=5)
  return iter(tfds.as_numpy(ds))
示例#25
0
 def __init__(self):
     """Load MNIST as supervised ``(image, label)`` train/valid/test splits.

     ``train`` is the first 55000 training images, ``valid`` the remaining
     training images and ``test`` the test split. The ReadConfig carries
     ``shuffle_seed=1``; ``shuffle_files`` is left at the ``tfds.load``
     default.
     """
     super(MNIST, self).__init__()
     split_spec = ['train[:55000]', 'train[55000:]', 'test']
     read_config = tfds.ReadConfig(shuffle_seed=1,
                                   shuffle_reshuffle_each_iteration=True)
     self.train, self.valid, self.test = tfds.load(name='mnist',
                                                   split=split_spec,
                                                   read_config=read_config,
                                                   as_supervised=True)
示例#26
0
    def _build_autocached_info(self, builder: tfds.core.DatasetBuilder):
        """Describe whether TFDS would auto-cache each split of ``builder``.

        Returns ``'Unknown'`` when there are no splits or no dataset size,
        ``'Yes'``/``'No'`` when every split is always/never cached, and
        otherwise a comma-separated breakdown naming the splits per category.
        """
        # Dicts (used as ordered sets) keep the sorted split order for output.
        always_cached, never_cached, unshuffle_cached = {}, {}, {}
        for raw_name in sorted(builder.info.splits.keys()):
            name = str(raw_name)
            cache_shuffled, cache_unshuffled = (
                builder._should_cache_ds(  # pylint: disable=protected-access
                    name,
                    shuffle_files=shuffle,
                    read_config=tfds.ReadConfig())
                for shuffle in (True, False))

            if cache_shuffled and cache_unshuffled:
                always_cached[name] = None
            elif not (cache_shuffled or cache_unshuffled):
                never_cached[name] = None
            else:
                # Only cached when shuffle_files is False.
                assert not cache_shuffled and cache_unshuffled
                unshuffle_cached[name] = None

        num_splits = len(builder.info.splits)
        if not num_splits or not builder.info.dataset_size:
            return 'Unknown'
        if len(always_cached) == num_splits:
            return 'Yes'  # Every split is auto-cached.
        if len(never_cached) == num_splits:
            return 'No'  # No split is ever auto-cached.

        # Mixed case: list each non-empty category with its split names.
        categories = (
            ('Yes ({})', always_cached),
            ('No ({})', never_cached),
            ('Only when `shuffle_files=False` ({})', unshuffle_cached),
        )
        return ', '.join(
            template.format(', '.join(split_names))
            for template, split_names in categories
            if split_names)
示例#27
0
def test_mocking_add_tfds_id():
  """`add_tfds_id=True` adds a scalar string `tfds_id` to mnist elements.

  The test name suggests it runs against mocked data.
  """
  ds = tfds.load(
      'mnist', split='train', read_config=tfds.ReadConfig(add_tfds_id=True))
  expected_spec = {
      'tfds_id': tf.TensorSpec(shape=(), dtype=tf.string),
      'image': tf.TensorSpec(shape=(28, 28, 1), dtype=tf.uint8),
      'label': tf.TensorSpec(shape=(), dtype=tf.int64),
  }
  assert ds.element_spec == expected_spec
  # Pulling a few elements must not raise.
  for _ in ds.take(3):
    pass
 def load_shard(self, file_instruction, shuffle_files=False, seed=None):
     """Read a single TFRecord shard of the builder's dataset.

     Goes through the builder's private ``_tfrecords_reader``; ``seed`` is the
     file-shuffle seed used when ``shuffle_files`` is true.
     """
     config = tfds.ReadConfig(shuffle_seed=seed)
     # pytype:disable=attribute-error
     reader = self.builder._tfrecords_reader  # pylint:disable=protected-access
     # pytype:enable=attribute-error
     return reader.read_files(
         [file_instruction],
         read_config=config,
         shuffle_files=shuffle_files)
示例#29
0
 def _load_tfds(self, *, split, shuffle_seed):
   """Load the LSUN subset named by ``self._subset`` ('church' or 'bedroom').

   ``split`` is 'train' or 'eval' (mapped to TFDS ``validation``). Images are
   left encoded. Files are shuffled, with ``shuffle_seed``, only when a seed
   is given.
   """
   subset_to_name = {'church': 'lsun/church_outdoor',
                     'bedroom': 'lsun/bedroom'}
   tfds_name = subset_to_name[self._subset]
   tfds_split = {'train': 'train', 'eval': 'validation'}[split]
   should_shuffle = shuffle_seed is not None
   read_config = (tfds.ReadConfig(shuffle_seed=shuffle_seed)
                  if should_shuffle else None)
   return tfds.load(
       tfds_name,
       split=tfds_split,
       shuffle_files=should_shuffle,
       read_config=read_config,
       decoders={'image': tfds.decode.SkipDecoding()})
 def load_shard(self, file_instruction):
     """Read a single TFRecord shard of the builder's dataset, unshuffled.

     Wraps the instruction in a ``FileInstructions`` container (with no
     per-shard example counts) for the builder's private reader.
     """
     instructions = tfds.core.tfrecords_reader.FileInstructions(
         file_instructions=[file_instruction],
         num_examples_per_shard=None,
     )
     reader = self.builder._tfrecords_reader  # pylint:disable=protected-access
     return reader.read_files(instructions,
                              read_config=tfds.ReadConfig(),
                              shuffle_files=False)