Code Example #1
File: base.py  Project: mcx/uncertainty-baselines
  def load(self,
           *,
           preprocess_fn=None,
           batch_size: int = -1) -> tf.data.Dataset:
    # Set up the in-distribution dataset using the provided dataset builder.
    if preprocess_fn:
      dataset_preprocess_fn = preprocess_fn
    else:
      dataset_preprocess_fn = (
          self._in_distribution_dataset._create_process_example_fn())  # pylint: disable=protected-access
    dataset_preprocess_fn = ops.compose(dataset_preprocess_fn,
                                        _create_ood_label_fn(True))
    dataset = self._in_distribution_dataset.load(
        preprocess_fn=dataset_preprocess_fn, batch_size=batch_size)
    dataset = dataset.map(
        _remove_fingerprint_id_key(self._in_distribution_dataset))

    # Set up the OOD dataset using this class.
    if preprocess_fn:
      ood_dataset_preprocess_fn = preprocess_fn
    else:
      ood_dataset_preprocess_fn = (
          super(_OodBaseDataset, self)._create_process_example_fn())
    ood_dataset_preprocess_fn = ops.compose(
        ood_dataset_preprocess_fn, _create_ood_label_fn(False))
    ood_dataset = super(_OodBaseDataset, self).load(
        preprocess_fn=ood_dataset_preprocess_fn, batch_size=batch_size)
    ood_dataset = ood_dataset.map(_remove_fingerprint_id_key(self))

    # Combine the two datasets.
    combined_dataset = dataset.concatenate(ood_dataset)
    if self._shuffle_datasets:
      combined_dataset = combined_dataset.shuffle(self._shuffle_buffer_size)
    return combined_dataset
Code Example #2
    def load(self,
             *,
             preprocess_fn=None,
             batch_size: int = -1) -> tf.data.Dataset:
      # Set up the in-distribution dataset using the provided dataset builder.
      if preprocess_fn:
        dataset_preprocess_fn = preprocess_fn
      else:
        dataset_preprocess_fn = (
            self._in_distribution_dataset._create_process_example_fn())  # pylint: disable=protected-access
      dataset_preprocess_fn = ops.compose(
          dataset_preprocess_fn,
          _create_ood_label_fn(True))
      dataset = self._in_distribution_dataset.load(
          preprocess_fn=dataset_preprocess_fn,
          batch_size=batch_size)
      dataset = dataset.map(
          _remove_fingerprint_id_key(self._in_distribution_dataset))

      # Set up the OOD dataset using this class.
      if preprocess_fn:
        ood_dataset_preprocess_fn = preprocess_fn
      else:
        ood_dataset_preprocess_fn = (
            super(_OodBaseDataset, self)._create_process_example_fn())
      ood_dataset_preprocess_fn = ops.compose(
          ood_dataset_preprocess_fn,
          _create_ood_label_fn(False))
      ood_dataset = super(_OodBaseDataset, self).load(
          preprocess_fn=ood_dataset_preprocess_fn,
          batch_size=batch_size)
      ood_dataset = ood_dataset.map(_remove_fingerprint_id_key(self))

      # Combine the two datasets.
      try:
        combined_dataset = dataset.concatenate(ood_dataset)
      except TypeError:
        logging.info(
            'Two datasets have different types, concat feature and label only')

        def clean_keys(example):
          # only keep features and labels, remove the rest
          return {
              'features': example['features'],
              'labels': example['labels'],
              'is_in_distribution': example['is_in_distribution']
          }

        combined_dataset = dataset.map(clean_keys).concatenate(
            ood_dataset.map(clean_keys))
      if self._shuffle_datasets:
        combined_dataset = combined_dataset.shuffle(self._shuffle_buffer_size)
      return combined_dataset
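
A minimal, self-contained sketch of the label-and-concatenate pattern the two examples above rely on, written with plain tf.data. The add_ood_label helper and the toy feature values are illustrative stand-ins, not part of the project's API; _create_ood_label_fn presumably plays the same role there.

import tensorflow as tf

def add_ood_label(is_in_distribution):
  # Tag every example with an 'is_in_distribution' flag before merging.
  def fn(example):
    example['is_in_distribution'] = tf.cast(is_in_distribution, tf.int32)
    return example
  return fn

in_dist = tf.data.Dataset.from_tensor_slices({'features': tf.ones([4, 2])})
ood = tf.data.Dataset.from_tensor_slices({'features': tf.zeros([3, 2])})

combined = (in_dist.map(add_ood_label(True))
            .concatenate(ood.map(add_ood_label(False)))
            .shuffle(buffer_size=7)
            .batch(2))

for batch in combined:
  print(batch['is_in_distribution'].numpy())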
Code Example #3
  def load(self, preprocess_fn: Optional[PreprocessFn]) -> tf.data.Dataset:
    if not preprocess_fn:
      preprocess_fn = self._default_preprocess_fn

    preprocess_fn = ops.compose(preprocess_fn, self.create_metadata)
    ds = self._dataset_builder.as_dataset(split=self._split,
                                          as_supervised=False)
    ds = ds.map(preprocess_fn)
    return ds.enumerate().map(_enumerated_to_metadata)
Code Example #4
  def load(self,
           preprocess_fn: Optional[PreprocessFn] = None) -> tf.data.Dataset:
    if not preprocess_fn:
      preprocess_fn = self._default_preprocess_fn

    preprocess_fn = ops.compose(preprocess_fn,
                                self.create_metadata)
    self._dataset_builder.download_and_prepare()
    ds = self._dataset_builder.as_dataset(
        split=self._split, as_supervised=False)
    return ds.map(preprocess_fn)
Code Example #5
  def load(self, preprocess_fn: Optional[PreprocessFn],
           batch_size: int) -> tf.data.Dataset:
    if not preprocess_fn:
      preprocess_fn = self._default_preprocess_fn

    preprocess_fn = ops.compose(preprocess_fn, self.create_metadata)
    ds = self._dataset_builder.as_dataset(split=self._split,
                                          as_supervised=False)
    # TODO(trandustin): Change to drop_remainder=False. For now, True aligns
    # with how results are currently measured in Uncertainty Baselines.
    ds_batched = ds.map(preprocess_fn).batch(batch_size, drop_remainder=True)
    return ds_batched.prefetch(tf.data.experimental.AUTOTUNE)
Code Example #6
  def load(self, preprocess_fn: Optional[PreprocessFn]) -> tf.data.Dataset:
    if not preprocess_fn:
      preprocess_fn = self._default_preprocess_fn

    preprocess_fn = ops.compose(preprocess_fn, self.create_metadata)
    ds = self._dataset_builder.as_dataset(split=self._split,
                                          as_supervised=False)
    ds = ds.map(preprocess_fn)

    def enumerated_to_metadata(position, features):
      features["metadata"]["element_id"] = position
      return features

    return ds.enumerate().map(enumerated_to_metadata)
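
The enumerate-to-metadata pattern in Examples #3 and #6 is plain tf.data: enumerate() yields (index, element) pairs, and the map folds the index back into the feature dict. A self-contained sketch with illustrative feature names:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({'features': tf.range(5)})

def enumerated_to_metadata(position, features):
  # Fold the enumeration index into the example as a stable per-element id.
  features['metadata'] = {'element_id': position}
  return features

for example in ds.enumerate().map(enumerated_to_metadata):
  print(example['metadata']['element_id'].numpy(), example['features'].numpy())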
Code Example #7
  def load(self, preprocess_fn: Optional[PreprocessFn],
           batch_size: int) -> tf.data.Dataset:
    if not preprocess_fn:
      preprocess_fn = self._default_preprocess_fn

    def create_element_id(features: Dict[str, Any]):
      """Hash the element id to compute a unique id."""
      assert "element_id" not in features, \
             "`element_id` should not be already present in the feature set."
      fingerprint_feature = features[self._fingerprint_key]
      features["element_id"] = ops.fingerprint_int64(fingerprint_feature)
      return features

    preprocess_fn = ops.compose(preprocess_fn, create_element_id,
                                self.create_metadata)
    self._dataset_builder.download_and_prepare()
    ds = self._dataset_builder.as_dataset(
        split=self._split, as_supervised=False)
    ds_batched = ds.map(preprocess_fn).batch(batch_size, drop_remainder=False)
    return ds_batched.prefetch(tf.data.experimental.AUTOTUNE)
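
A sketch of how a deterministic integer id can be derived from a string fingerprint field, as create_element_id does above. tf.strings.to_hash_bucket_fast is only a stand-in for the project's ops.fingerprint_int64 helper, and the key names are illustrative.

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({'id': ['img_0001', 'img_0002']})

def create_element_id(features):
  # Hash the fingerprint field into a (near-)unique integer id.
  features['element_id'] = tf.strings.to_hash_bucket_fast(
      features['id'], num_buckets=2**30)
  return features

for example in ds.map(create_element_id):
  print(example['element_id'].numpy())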
Code Example #8
  def _load(self,
            *,
            preprocess_fn: Optional[PreProcessFn] = None,
            batch_size: int = -1) -> tf.data.Dataset:
    """Transforms the dataset from builder.as_dataset() to batch, repeat, etc.

    Note that we do not handle replication/sharding here, because the
    DistributionStrategy experimental_distribute_dataset() will shard the input
    files for us.

    Args:
      preprocess_fn: an optional preprocessing function, if not provided then a
        subclass must define _create_process_example_fn() which will be used to
        preprocess the data.
      batch_size: the batch size to use.

    Returns:
      A tf.data.Dataset of elements that are a dict with keys 'features' and
      'labels' and their corresponding Tensor values.
    """
    if batch_size <= 0:
      raise ValueError(
          'Must provide a positive batch size, received {}.'.format(batch_size))

    self._seed, self._shuffle_seed = tf.random.experimental.stateless_split(
        self._seed, num=2)

    if self._download_data:
      self._dataset_builder.download_and_prepare()
    dataset = self._dataset_builder.as_dataset(
        split=self._split, decoders=self._decoders)

    # Possibly cache the original dataset before preprocessing is applied.
    if self._cache:
      dataset = dataset.cache()

    # This must be done *before* repeating the dataset so that each example has
    # a unique and stable fingerprint key.
    if self._add_fingerprint_key:
      dataset = dataset.enumerate()
      add_fingerprint_key_fn = self._add_enumerate_id(self._fingerprint_key)
      dataset = dataset.map(
          add_fingerprint_key_fn,
          num_parallel_calls=self._num_parallel_parser_calls)

    # If we are truncating the validation/test dataset (self._drop_remainder)
    # we may as well repeat to speed things up.
    # TODO(nband): benchmark.
    # TODO(trandustin): Make this differing behavior consistent with other
    # ub.datasets.
    if (self.name in _drd_datasets and not self._is_training and
        self._drop_remainder):
      dataset = dataset.repeat()
      logging.info('Repeating dataset %s (training mode: %s).', self.name,
                   self._is_training)

    # Shuffle and repeat only for the training split.
    if self._is_training:
      dataset = dataset.shuffle(
          self._shuffle_buffer_size,
          seed=tf.cast(self._shuffle_seed[0], tf.int64),
          reshuffle_each_iteration=True)
      dataset = dataset.repeat()

    # Enumerate the dataset to generate a unique per-example, per-step id, that
    # is then added to the feature dict as `self._enumerate_id_key`.
    # Note that this is distinct from just a per-example id that is used by
    # Robustness Metrics to identify examples, because we want an id that is
    # different for each step so that we can fold it into a source of randomness
    # for deterministic random preprocessing.
    # This must be done *after* repeating the dataset so that each example has a
    # different key per-step.
    dataset = dataset.enumerate()
    add_per_step_id_key_fn = self._add_enumerate_id(self._enumerate_id_key)

    if preprocess_fn is None:
      preprocess_fn = self._create_process_example_fn()

    # `self._create_element_id` must come before `preprocess_fn` so that we
    # guarantee the field with key `self._fingerprint_key` is still present
    # (many preprocess_fn's may not return it).
    preprocess_fn = ops.compose(
        add_per_step_id_key_fn, self._create_element_id, preprocess_fn)
    dataset = dataset.map(
        preprocess_fn,
        num_parallel_calls=self._num_parallel_parser_calls)

    # Note that unless the default value of `drop_remainder=True` is overridden
    # in `__init__`, we always drop the last batch when the batch size does not
    # evenly divide the number of examples.
    # TODO(znado): add padding to last partial eval batch.
    dataset = dataset.batch(batch_size, drop_remainder=self._drop_remainder)

    process_batch_fn = self._create_process_batch_fn(batch_size)  # pylint: disable=assignment-from-none
    if process_batch_fn:
      dataset = dataset.map(
          process_batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if not self._is_training and self.name not in _drd_datasets:
      dataset = dataset.cache()
    else:
      if not self._cache:
        logging.info(
            'Not caching dataset %s (training mode: %s).',
            self.name,
            self._is_training)

    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # The AutoSharding policy in DistributionStrategy defaults to AUTO, which
    # will fall back to DATA if it can, which is safe to do but possibly
    # wasteful compared to `distribute_datasets_from_function`.
    options = tf.data.Options()
    # Optimize dataset performance.
    # Keep commented out, unclear if will always improve performance.
    # options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.map_fusion = True
    options.experimental_optimization.map_parallelization = True
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    dataset = dataset.with_options(options)
    return dataset
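
As the docstring notes, sharding is left to the DistributionStrategy. A hedged sketch of how a dataset returned by a load()-style method is typically consumed under a strategy; the toy dataset below stands in for an actual uncertainty-baselines dataset instance.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Stand-in for `dataset = some_ub_dataset.load(batch_size=64)`.
dataset = tf.data.Dataset.from_tensor_slices(
    {'features': tf.zeros([8, 2]), 'labels': tf.zeros([8])}).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def step(batch):
  # Per-replica computation over the 'features'/'labels' dict.
  return tf.reduce_mean(batch['features'])

for per_replica_batch in dist_dataset:
  strategy.run(step, args=(per_replica_batch,))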
Code Example #9
  def load(self,
           *,
           preprocess_fn: Optional[PreProcessFn] = None,
           batch_size: int = -1) -> tf.data.Dataset:
    """Transforms the dataset from builder.as_dataset() to batch, repeat, etc.

    Note that we do not handle replication/sharding here, because the
    DistributionStrategy experimental_distribute_dataset() will shard the input
    files for us.

    Args:
      preprocess_fn: an optional preprocessing function, if not provided then a
        subclass must define _create_process_example_fn() which will be used to
        preprocess the data.
      batch_size: the batch size to use.

    Returns:
      A tf.data.Dataset of elements that are a dict with keys 'features' and
      'labels' and their corresponding Tensor values.
    """
    if batch_size <= 0:
      raise ValueError(
          'Must provide a positive batch size, received {}.'.format(batch_size))

    if self._download_data:
      self._dataset_builder.download_and_prepare(
          download_dir=self._dataset_builder.data_dir)
    dataset = self._dataset_builder.as_dataset(self._split)

    # Map the parser over the dataset.
    if preprocess_fn is None:
      preprocess_fn = self._create_process_example_fn()
    if self._add_enumerate:
      # If necessary, enumerate the dataset to generate a unique per-example id,
      # that is then added to the feature dict in
      # `self._create_enumerate_preprocess_fn` with key `self._fingerprint_key`.
      dataset = dataset.enumerate()
      preprocess_fn = self._create_enumerate_preprocess_fn(preprocess_fn)
    preprocess_fn = ops.compose(preprocess_fn, self._create_element_id)
    dataset = dataset.map(
        preprocess_fn,
        num_parallel_calls=self._num_parallel_parser_calls)

    # Shuffle and repeat only for the training split.
    if self._is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)
      dataset = dataset.repeat()

    # Note that unless the default value of `drop_remainder=True` is overridden
    # in `__init__`, we always drop the last batch when the batch size does not
    # evenly divide the number of examples.
    # TODO(znado): add padding to last partial eval batch.
    dataset = dataset.batch(batch_size, drop_remainder=self._drop_remainder)

    dataset = dataset.prefetch(-1)

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dataset = dataset.with_options(options)
    return dataset