Ejemplo n.º 1
0
 def test_map_seed_manager(self):
   """Checks that nested map_seed_manager contexts install and restore seeds.

   Entering a manager sets `utils._NEXT_MAP_SEED` to its seed; leaving it
   restores the value that was current on entry, even if the seed was
   advanced while inside the context.
   """
   outer_seed, inner_seed, step = 42, 410, 10

   utils._NEXT_MAP_SEED = None
   self.assertIsNone(utils._NEXT_MAP_SEED)
   with utils.map_seed_manager(outer_seed):
     self.assertEqual(utils._NEXT_MAP_SEED, outer_seed)
     with utils.map_seed_manager(inner_seed):
       self.assertEqual(utils._NEXT_MAP_SEED, inner_seed)
       # Advance the seed inside the inner context...
       utils._NEXT_MAP_SEED = utils._NEXT_MAP_SEED + step
       self.assertEqual(utils._NEXT_MAP_SEED, inner_seed + step)
     # ...then confirm the outer seed was restored before advancing it again.
     utils._NEXT_MAP_SEED = utils._NEXT_MAP_SEED + step
     self.assertEqual(utils._NEXT_MAP_SEED, outer_seed + step)
   # Leaving the outermost context brings back the initial None.
   self.assertIsNone(utils._NEXT_MAP_SEED)
  def preprocess_postcache(
      self,
      dataset: tf.data.Dataset,
      sequence_length: Optional[Mapping[str, int]],
      seed: Optional[int] = None
    ) -> tf.data.Dataset:
    """Runs preprocessing steps after the optional CacheDatasetPlaceholder.

    Args:
      dataset: a tf.data.Dataset
      sequence_length: dict mapping feature key to int length for that feature.
        If None, the features will not be truncated.
      seed: an optional random seed for deterministic preprocessing. When
        caching is supported, it is offset by `42 * self._cache_step_idx`
        before use, so post-cache steps do not reuse pre-cache seeds while
        still depending on the caller's seed.

    Returns:
      a tf.data.Dataset, after being checked by `self._validate_dataset` with
      expected_output_type=tf.int64, expected_output_rank=1 and
      ensure_no_eos=True.
    """
    start_idx = 0
    if self.supports_caching:
      # Skip a sufficient number of seeds to avoid duplicating any from
      # pre-cache preprocessing. The offset is ADDED to the caller's seed;
      # replacing the seed would make every user seed produce identical
      # post-cache randomness.
      if seed is not None:
        seed += 42 * self._cache_step_idx
      # Steps before index `_cache_step_idx` are applied by
      # preprocess_precache; the step at that index (presumably the
      # placeholder itself) is skipped.
      start_idx = self._cache_step_idx + 1
    # Otherwise preprocess_precache applied nothing, so every preprocessor
    # must run here.
    with utils.map_seed_manager(seed):
      dataset = self._preprocess_dataset(
          dataset,
          self._preprocessors[start_idx:],
          sequence_length=sequence_length,
      )
    dataset = self._validate_dataset(
        dataset,
        expected_output_type=tf.int64,
        expected_output_rank=1,
        error_label="preprocessing",
        ensure_no_eos=True)
    return dataset
  def preprocess_precache(
      self,
      dataset: tf.data.Dataset,
      seed: Optional[int] = None
    ) -> tf.data.Dataset:
    """Runs preprocessing steps before the optional CacheDatasetPlaceholder.

    Args:
      dataset: a tf.data.Dataset
      seed: an optional random seed, installed through
        `utils.map_seed_manager` while the pre-cache steps are applied.

    Returns:
      a tf.data.Dataset. When caching is not supported, `dataset` is returned
      unchanged.
    """
    if not self.supports_caching:
      # Nothing is applied at this stage without cache support.
      return dataset

    with utils.map_seed_manager(seed):
      steps_before_cache = self._preprocessors[:self._cache_step_idx]
      # No sequence_length is passed: only steps strictly before the cache
      # index run here.
      return self._preprocess_dataset(dataset, steps_before_cache)