Ejemplo n.º 1
0
def input_reader_generator(params: cfg.DataConfig,
                           **kwargs) -> core_input_reader.InputReader:
  """Builds the input reader that matches the given data config.

  A `CombinationDatasetInputReader` is created when the config is for training
  and carries a truthy `pseudo_label_data` entry; in every other case a plain
  `InputReader` is created.

  Args:
    params: A config_definitions.DataConfig object.
    **kwargs: Extra keyword arguments forwarded to the reader's constructor.

  Returns:
    The constructed input reader.
  """
  use_pseudo_labels = params.is_training and params.get(
      'pseudo_label_data', False)
  if not use_pseudo_labels:
    return core_input_reader.InputReader(params, **kwargs)

  # The dataset function for the pseudo-label data is selected from the file
  # type declared in its own sub-config.
  pseudo_label_dataset_fn = dataset_fn_util.pick_dataset_fn(
      params.pseudo_label_data.file_type)
  return vision_input_reader.CombinationDatasetInputReader(
      params, pseudo_label_dataset_fn=pseudo_label_dataset_fn, **kwargs)
Ejemplo n.º 2
0
  def __init__(self,
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               sample_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               transform_and_batch_fn: Optional[Callable[
                   [tf.data.Dataset, Optional[tf.distribute.InputContext]],
                   tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Sets up the reader from a data config and optional pipeline hooks.

    The data source is either a set of files (when `params.input_path` is set)
    or a TFDS dataset (otherwise). The remaining config fields and the hook
    callables are stored on the instance for later use.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: Dataset constructor used to consume the input files, for
        example `tf.data.TFRecordDataset`.
      decoder_fn: Optional callable that decodes a serialized data string into
        the raw tensor dictionary.
      sample_fn: Optional callable mapping a `tf.data.Dataset` to a transformed
        dataset. It samples the decoded raw tensors before `parser_fn` runs.
      parser_fn: Optional callable that turns the decoded raw tensors dict into
        a dictionary of tensors the model can consume. Runs after
        `decoder_fn`.
      transform_and_batch_fn: Optional callable taking a `tf.data.Dataset` and
        an optional `tf.distribute.InputContext` and returning a
        `tf.data.Dataset`. Runs after `parser_fn` to transform and batch the
        dataset; when None, the dataset is batched into the per-replica batch
        size after `parser_fn`.
      postprocess_fn: Optional callable applied to the batched tensors, i.e.
        after batching.

    Raises:
      ValueError: If both `input_path` and `tfds_name` are set, or if the TFDS
        path is taken and `tfds_split` is not specified.
    """
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))

    # Only one of the two sources below is populated, depending on the config.
    self._tfds_builder = None
    self._matched_files = []
    if not params.input_path:
      # No file input was given, so the dataset is read from TFDS.
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)
    else:
      self._matched_files = self._match_files(params.input_path)

    # Plain copies of the pipeline-related config fields.
    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._deterministic = params.deterministic
    self._sharding = params.sharding

    # TFDS-specific config fields.
    self._tfds_split = params.tfds_split
    self._tfds_download = params.tfds_download
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    # User-supplied pipeline hooks.
    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
    self._postprocess_fn = postprocess_fn

    # With tf.data service enabled, every data service worker should draw its
    # own random seed, so no fixed seed is stored in that case.
    if params.enable_tf_data_service:
      self._seed = None
    else:
      self._seed = _get_random_integer()

    # The service counts as enabled only when an address is configured too.
    self._enable_tf_data_service = (
        params.enable_tf_data_service and params.tf_data_service_address)
    self._tf_data_service_address = params.tf_data_service_address
    if self._enable_tf_data_service:
      # A random suffix on the job name keeps tf.data service from reusing
      # previous state after a TPU worker preemption. Note that the two
      # attributes below are only defined when the service is enabled.
      job_name_suffix = str(self.static_randnum)
      self._tf_data_service_job_name = (
          params.tf_data_service_job_name + job_name_suffix)
      self._enable_round_robin_tf_data_service = params.get(
          'enable_round_robin_tf_data_service', False)
    def __init__(self,
                 params: cfg.DataConfig,
                 dataset_fn=tf.data.TFRecordDataset,
                 decoder_fn: Optional[Callable[..., Any]] = None,
                 sample_fn: Optional[Callable[..., Any]] = None,
                 parser_fn: Optional[Callable[..., Any]] = None,
                 transform_and_batch_fn: Optional[Callable[
                     [tf.data.Dataset, Optional[tf.distribute.InputContext]],
                     tf.data.Dataset]] = None,
                 postprocess_fn: Optional[Callable[..., Any]] = None):
        """Initializes an InputReader instance.

        The data source is either the files matched by `params.input_path`
        or, when `input_path` is empty, a TFDS dataset built from
        `params.tfds_name`.

        Args:
          params: A config_definitions.DataConfig object.
          dataset_fn: Dataset constructor used to consume the input files,
            for example `tf.data.TFRecordDataset`.
          decoder_fn: An optional `callable` that takes the serialized data
            string and decodes it into the raw tensor dictionary.
          sample_fn: An optional `callable` that takes a `tf.data.Dataset`
            object as input and outputs the transformed dataset. It performs
            sampling on the decoded raw tensors dict before the parser_fn.
          parser_fn: An optional `callable` that takes the decoded raw
            tensors dict and parses it into a dictionary of tensors that can
            be consumed by the model. It will be executed after decoder_fn.
          transform_and_batch_fn: An optional `callable` that takes a
            `tf.data.Dataset` object and an optional
            `tf.distribute.InputContext` as input, and returns a
            `tf.data.Dataset` object. It will be executed after `parser_fn`
            to transform and batch the dataset; if None, after `parser_fn`
            is executed, the dataset will be batched into per-replica batch
            size.
          postprocess_fn: An optional `callable` that processes batched
            tensors. It will be executed after batching.

        Raises:
          ValueError: If both `input_path` and `tfds_name` are set, if
            `input_path` has an unsupported type or element type, if a
            wildcard pattern or the whole `input_path` matches no files, or
            if the TFDS path is taken without `tfds_split`.
        """
        if params.input_path and params.tfds_name:
            raise ValueError(
                'At most one of `input_path` and `tfds_name` can be '
                'specified, but got %s and %s.' %
                (params.input_path, params.tfds_name))
        self._tfds_builder = None
        self._matched_files = []
        if params.input_path:
            # Read dataset from files.
            usage = (
                '`input_path` should be either (1) a str indicating a file '
                'path/pattern, or (2) a str indicating multiple file '
                'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
                '"a,b,c", or (3) a list of str, each of which is a file '
                'path/pattern or multiple file paths/patterns separated by '
                'comma, but got: %s')
            # NOTE: the `%` operand is always wrapped in a 1-tuple below.
            # `input_path` may itself be a tuple, and a bare tuple operand
            # would be unpacked as the format arguments, raising TypeError
            # instead of the intended ValueError.
            if isinstance(params.input_path, str):
                input_path_list = [params.input_path]
            elif isinstance(params.input_path, (list, tuple)):
                if any(not isinstance(x, str) for x in params.input_path):
                    raise ValueError(usage % (params.input_path,))
                input_path_list = params.input_path
            else:
                raise ValueError(usage % (params.input_path,))

            for input_path in input_path_list:
                # Each entry may hold several comma-separated paths/patterns.
                for input_pattern in input_path.strip().split(','):
                    input_pattern = input_pattern.strip()
                    if not input_pattern:
                        continue
                    if '*' in input_pattern or '?' in input_pattern:
                        # Wildcard patterns are expanded and must match.
                        tmp_matched_files = tf.io.gfile.glob(input_pattern)
                        if not tmp_matched_files:
                            raise ValueError('%s does not match any files.' %
                                             input_pattern)
                        self._matched_files.extend(tmp_matched_files)
                    else:
                        # Plain paths are taken as-is, without checking that
                        # they exist.
                        self._matched_files.append(input_pattern)

            if not self._matched_files:
                raise ValueError('%s does not match any files.' %
                                 (params.input_path,))
        else:
            # Read dataset from TFDS.
            if not params.tfds_split:
                raise ValueError(
                    '`tfds_name` is %s, but `tfds_split` is not specified.' %
                    (params.tfds_name,))
            self._tfds_builder = tfds.builder(params.tfds_name,
                                              data_dir=params.tfds_data_dir)

        self._global_batch_size = params.global_batch_size
        self._is_training = params.is_training
        self._drop_remainder = params.drop_remainder
        self._shuffle_buffer_size = params.shuffle_buffer_size
        self._cache = params.cache
        self._cycle_length = params.cycle_length
        self._block_length = params.block_length
        self._deterministic = params.deterministic
        self._sharding = params.sharding
        self._tfds_split = params.tfds_split
        self._tfds_download = params.tfds_download
        self._tfds_as_supervised = params.tfds_as_supervised
        self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

        self._dataset_fn = dataset_fn
        self._decoder_fn = decoder_fn
        self._sample_fn = sample_fn
        self._parser_fn = parser_fn
        self._transform_and_batch_fn = transform_and_batch_fn
        self._postprocess_fn = postprocess_fn
        self._seed = _get_random_integer()

        # The service counts as enabled only when an address is configured.
        self._enable_tf_data_service = (params.enable_tf_data_service
                                        and params.tf_data_service_address)
        self._tf_data_service_address = params.tf_data_service_address
        if self._enable_tf_data_service:
            # Add a random seed as the tf.data service job name suffix, so
            # tf.data service doesn't reuse the previous state if TPU worker
            # gets preempted. The two attributes below are only defined when
            # the service is enabled.
            self._tf_data_service_job_name = (params.tf_data_service_job_name +
                                              str(self.static_randnum))
            self._enable_round_robin_tf_data_service = params.get(
                'enable_round_robin_tf_data_service', False)
Ejemplo n.º 4
0
    def __init__(self,
                 params: cfg.DataConfig,
                 dataset_fn=tf.data.TFRecordDataset,
                 decoder_fn: Optional[Callable[..., Any]] = None,
                 combine_fn: Optional[Callable[..., Any]] = None,
                 sample_fn: Optional[Callable[..., Any]] = None,
                 parser_fn: Optional[Callable[..., Any]] = None,
                 transform_and_batch_fn: Optional[Callable[
                     [tf.data.Dataset, Optional[tf.distribute.InputContext]],
                     tf.data.Dataset]] = None,
                 postprocess_fn: Optional[Callable[..., Any]] = None):
        """Sets up the reader from a data config and optional pipeline hooks.

        The data source is either the files resolved from
        `params.input_path` or, when `input_path` is empty, a TFDS dataset.
        The remaining config fields and hook callables are stored on the
        instance.

        Args:
          params: A config_definitions.DataConfig object.
          dataset_fn: Dataset constructor used to consume the input files,
            for example `tf.data.TFRecordDataset`.
          decoder_fn: Optional callable that decodes a serialized data string
            into the raw tensor dictionary.
          combine_fn: Optional callable that takes a dictionary of
            `tf.data.Dataset` objects and returns one combined dataset. Runs
            after `decoder_fn` and before `sample_fn`. Required when
            `params.input_path` is a config object (dictionary-like).
          sample_fn: Optional callable mapping a `tf.data.Dataset` to a
            transformed dataset. It samples the decoded raw tensors before
            `parser_fn` runs.
          parser_fn: Optional callable that turns the decoded raw tensors
            dict into a dictionary of tensors the model can consume. Runs
            after `decoder_fn`.
          transform_and_batch_fn: Optional callable taking a
            `tf.data.Dataset` and an optional `tf.distribute.InputContext`
            and returning a `tf.data.Dataset`. Runs after `parser_fn` to
            transform and batch the dataset; when None, the dataset is
            batched into the per-replica batch size after `parser_fn`.
          postprocess_fn: Optional callable applied to the batched tensors,
            i.e. after batching.

        Raises:
          ValueError: If both `input_path` and `tfds_name` are set, if
            `input_path` is a config object but no `combine_fn` is given, or
            if the TFDS path is taken and `tfds_split` is not specified.
        """
        if params.input_path and params.tfds_name:
            raise ValueError(
                'At most one of `input_path` and `tfds_name` can be '
                'specified, but got %s and %s.' %
                (params.input_path, params.tfds_name))

        input_path_is_config = isinstance(params.input_path,
                                          cfg.base_config.Config)
        if input_path_is_config and combine_fn is None:
            raise ValueError(
                'A `combine_fn` is required if the `input_path` is a dictionary.'
            )

        # Only one of the two sources below is populated.
        self._tfds_builder = None
        self._matched_files = None
        if params.input_path:
            self._matched_files = self.get_files(params.input_path)
        else:
            # No file input was given, so the dataset is read from TFDS.
            if not params.tfds_split:
                raise ValueError(
                    '`tfds_name` is %s, but `tfds_split` is not specified.' %
                    params.tfds_name)
            self._tfds_builder = tfds.builder(params.tfds_name,
                                              data_dir=params.tfds_data_dir)

        # Plain copies of the pipeline-related config fields.
        self._global_batch_size = params.global_batch_size
        self._is_training = params.is_training
        self._drop_remainder = params.drop_remainder
        self._shuffle_buffer_size = params.shuffle_buffer_size
        self._cache = params.cache
        self._cycle_length = params.cycle_length
        self._block_length = params.block_length
        self._deterministic = params.deterministic
        self._sharding = params.sharding

        # TFDS-specific config fields.
        self._tfds_split = params.tfds_split
        self._tfds_as_supervised = params.tfds_as_supervised
        self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

        # User-supplied pipeline hooks.
        self._dataset_fn = dataset_fn
        self._decoder_fn = decoder_fn
        self._combine_fn = combine_fn
        self._sample_fn = sample_fn
        self._parser_fn = parser_fn
        self._transform_and_batch_fn = transform_and_batch_fn
        self._postprocess_fn = postprocess_fn

        self._seed = params.seed
        # Fall back to autotuning when no prefetch buffer size is configured.
        self._prefetch_buffer_size = (params.prefetch_buffer_size
                                      or tf.data.experimental.AUTOTUNE)

        # With tf.data service enabled, each data service worker should draw
        # its own random seed, so the seed is cleared. Sharding is switched
        # off as well, because tf.data service decides how workers shard the
        # data through `processing_mode` in the distribute method.
        if params.enable_tf_data_service:
            self._seed = None
            self._sharding = False

        # The service counts as enabled only when an address is configured.
        self._enable_tf_data_service = (params.enable_tf_data_service
                                        and params.tf_data_service_address)
        self._tf_data_service_address = params.tf_data_service_address
        self._enable_shared_tf_data_service_between_parallel_trainers = (
            params.enable_shared_tf_data_service_between_parallel_trainers)
        self._apply_tf_data_service_before_batching = (
            params.apply_tf_data_service_before_batching)
        self._trainer_id = params.trainer_id
        if self._enable_tf_data_service:
            # The job name gets a random suffix so that tf.data service does
            # not reuse previous state after a TPU worker preemption.
            # The global batch size is part of the name too: when batch size
            # is tuned with vizier while tf.data service is enabled, every
            # trial needs its own job name. From tf.data's perspective a
            # dataset with a different batch size is a different instance,
            # and reading from another trial's job would cause a mismatch on
            # the batch dimension.
            self._tf_data_service_job_name = (
                f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
                f'{self.static_randnum}')
            self._enable_round_robin_tf_data_service = params.get(
                'enable_round_robin_tf_data_service', False)
            if self._enable_shared_tf_data_service_between_parallel_trainers:
                # A shared tf.data service means a single service instance
                # used by all parallel trainers, so the batch size is left
                # out of the job name. If the global batch size differs
                # across trainers,
                # params.apply_tf_data_service_before_batching should be set
                # to true, since tf.data service treats different batch sizes
                # as separate service instances.
                self._tf_data_service_job_name = (
                    f'{params.tf_data_service_job_name}_{self.static_randnum}')