Example #1
    # Assumed context (TF Model Garden-style detection dataloader layout):
    # import tensorflow as tf
    # from official.vision.detection.dataloader import factory
    # from official.vision.detection.dataloader import mode_keys as ModeKeys
    def get_dataset(config_params, file_pattern, dataset_type, mode):
        is_training = (mode == ModeKeys.TRAIN)
        if dataset_type == 'tfrecord':
            dataset_cls = tf.data.TFRecordDataset
            parser_fn = factory.parser_generator(config_params, mode)
        else:
            raise ValueError('Dataset type %s is not supported.' %
                             dataset_type)

        if ',' in file_pattern:
            dataset = tf.data.Dataset.from_tensor_slices(
                file_pattern.split(','))
        else:
            dataset = tf.data.Dataset.list_files(file_pattern,
                                                 shuffle=is_training)
        if is_training:
            dataset = dataset.repeat()

        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(
                lambda file_name: dataset_cls(file_name).prefetch(1),
                cycle_length=32,
                sloppy=is_training))

        if is_training:
            dataset = dataset.shuffle(64)

        return dataset, parser_fn
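A minimal usage sketch, assuming the TF Model Garden-style factory and ModeKeys noted above; the file pattern and batch size here are hypothetical:

    dataset, parser_fn = get_dataset(
        config_params, 'gs://my-bucket/train-*.tfrecord', 'tfrecord',
        ModeKeys.TRAIN)
    # Parse serialized records, then batch and prefetch for training.
    dataset = dataset.map(parser_fn, num_parallel_calls=64)
    dataset = dataset.batch(64, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)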
Example #2
    def dataset_fn(params, mode):
        """Creates and returns a pre-batched tf.data.Dataset."""
        # `params` is unused; `dataset_type`, `config_params` and
        # `file_pattern` are free variables captured from the enclosing scope.
        del params
        is_training = (mode == ModeKeys.TRAIN)
        if dataset_type == 'tfrecord':
            dataset_cls = tf.data.TFRecordDataset
            parser_fn = factory.parser_generator(config_params, mode)
        else:
            raise ValueError('Dataset type %s is not supported.' %
                             dataset_type)
        dataset = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
        if is_training:
            dataset = dataset.repeat()

        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(
                lambda file_name: dataset_cls(file_name).prefetch(1),
                cycle_length=32,
                sloppy=is_training))

        if is_training:
            dataset = dataset.shuffle(64)

        # Parses the fetched records into input tensors for the model function.
        dataset = dataset.map(parser_fn, num_parallel_calls=64)

        return dataset
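Because `dataset_fn` discards its `params` argument and reads `config_params`, `file_pattern` and `dataset_type` from the enclosing scope, it only makes sense as a closure. A sketch of a hypothetical wrapper that produces it:

    def make_dataset_fn(config_params, file_pattern, dataset_type='tfrecord'):
        # Hypothetical factory: closes over the configuration so that
        # dataset_fn keeps the (params, mode) signature its caller expects.
        def dataset_fn(params, mode):
            ...  # body exactly as in Example #2 above
        return dataset_fn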
Example #3
  def __init__(self, file_pattern, params, mode):
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset
    self._transpose_input = hasattr(params, 'train') and hasattr(
        params.train, 'transpose_input') and params.train.transpose_input
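The `hasattr` chain guards against params objects that lack a `train` section. An equivalent, more compact spelling using `getattr` with defaults (a sketch, not the original code):

    # Falls back to False if `train` or `transpose_input` is missing.
    self._transpose_input = bool(
        getattr(getattr(params, 'train', None), 'transpose_input', False))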
Example #4
  def __init__(self, file_pattern, params, mode, dataset_type='tfrecord'):
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    if dataset_type == 'tfrecord':
      self._dataset_fn = tf.data.TFRecordDataset
      self._parser_fn = factory.parser_generator(params, mode)
    else:
      raise ValueError('Dataset type %s is not supported.' % dataset_type)

    self._transpose_input = params.train.transpose_input
    self._space_to_depth_block_size = params.architecture.space_to_depth_block_size
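Unlike Example #3, this constructor reads `params.train.transpose_input` and `params.architecture.space_to_depth_block_size` unguarded, so both sections must exist. A sketch of the minimal params layout it assumes, presuming `params_dict.ParamsDict` accepts a nested dict as in the TF Model Garden; the `InputReader` class name is hypothetical:

    params = params_dict.ParamsDict({
        'train': {'transpose_input': False},
        'architecture': {'space_to_depth_block_size': 1},
    })
    reader = InputReader('gs://my-bucket/train-*.tfrecord', params,
                         ModeKeys.TRAIN)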
Example #5
  def __init__(self, file_pattern, params, mode, dataset_type='tfrecord'):
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    if dataset_type == 'tfrecord':
      self._dataset_fn = tf.data.TFRecordDataset
      self._parser_fn = factory.parser_generator(params, mode)
    else:
      raise ValueError('Dataset type %s is not supported.' % dataset_type)

    self._transpose_input = hasattr(params, 'train') and hasattr(
        params.train, 'transpose_input') and params.train.transpose_input
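With the guard from Example #3 restored, a params object lacking a `train` section degrades gracefully. An illustrative check with a plain namespace object (not the original code):

    from types import SimpleNamespace

    bare = SimpleNamespace()  # no 'train' attribute
    # hasattr(bare, 'train') is False, so the `and` chain short-circuits and
    # _transpose_input becomes False instead of raising AttributeError.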
Example #6
    def __init__(self,
                 file_pattern: Text,
                 params: params_dict.ParamsDict,
                 mode: Text,
                 batch_size: int,
                 num_examples: Optional[int] = -1):
        """Initialize.

    Args:
      file_pattern: the file pattern for the data example (TFRecords).
      params: the parameter object for constructing example parser and model.
      mode: ModeKeys.TRAIN or ModeKeys.Eval
      batch_size: the data batch size.
      num_examples: If positive, only takes this number of examples and raise
        tf.errors.OutOfRangeError after that. If non-positive, it will be
        ignored.
    """
        assert file_pattern is not None
        assert mode is not None
        assert batch_size is not None
        self._file_pattern = file_pattern
        self._mode = mode
        self._is_training = (mode == ModeKeys.TRAIN)
        self._batch_size = batch_size
        self._num_examples = num_examples
        self._parser_fn = factory.parser_generator(params, mode)
        self._dataset_fn = tf.data.TFRecordDataset

        self._input_sharding = (not self._is_training)
        try:
            if self._is_training:
                self._input_sharding = params.train.input_sharding
            else:
                self._input_sharding = params.eval.input_sharding
        except AttributeError:
            pass
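The docstring implies that a positive `num_examples` truncates the stream, which is what `tf.data.Dataset.take` provides. A sketch of how `_num_examples` is presumably applied where the dataset is built (an assumption; that code is not shown in this snippet):

    # Sketch (assumption): applied in this reader's dataset-building code,
    # matching the docstring's OutOfRangeError behavior.
    if self._num_examples > 0:
        dataset = dataset.take(self._num_examples)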
Example #7
  def __init__(self, file_pattern, params, mode):
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset
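These constructors only store configuration; a reader like this typically pairs them with a `__call__` that assembles the pipeline from the stored fields, along the lines of Examples #1 and #2. A hedged sketch; the signature and batching are assumptions, not the original code:

  def __call__(self, params):
    batch_size = params['batch_size']
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)
    if self._is_training:
      dataset = dataset.repeat()
    # Interleave reads across shards; sloppy ordering speeds up training.
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            lambda file_name: self._dataset_fn(file_name).prefetch(1),
            cycle_length=32,
            sloppy=self._is_training))
    if self._is_training:
      dataset = dataset.shuffle(64)
    dataset = dataset.map(self._parser_fn, num_parallel_calls=64)
    return dataset.batch(batch_size, drop_remainder=True)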