Exemplo n.º 1
0
 def testGlobListShardedFilePatterns(self, specs, expected_files):
     # Materialize every expected file on disk first, so the glob call
     # below has something to match against.
     created_paths = [
         test_utils.test_tmpfile(name, '') for name in expected_files
     ]
     # Build the comma-joined full spec string. Calling test_tmpfile with a
     # single argument only resolves the path; it does not create files.
     joined_specs = ','.join(
         test_utils.test_tmpfile(one_spec) for one_spec in specs.split(','))
     self.assertEqual(sorted(set(created_paths)),
                      io.glob_list_sharded_file_patterns(joined_specs))
Exemplo n.º 2
0
def get_one_example_from_examples_path(source, proto=None):
    """Get the first record from `source`.

    Args:
      source: str. A pattern or a comma-separated list of patterns that
        represent file names.
      proto: A proto class. proto.FromString() will be called on each
        serialized record in path to parse it.

    Returns:
      The first record, or None if every matched file is empty.

    Raises:
      ValueError: if `source` does not match any files.
    """
    files = sharded_file_utils.glob_list_sharded_file_patterns(source)
    if not files:
        raise ValueError(
            'Cannot find matching files with the pattern "{}"'.format(source))
    for f in files:
        try:
            # Use the builtin next() rather than the Python-2-only
            # generator .next() method, so this works on Python 3 too.
            return next(tfrecord.read_tfrecords(f, proto=proto))
        except StopIteration:
            # A StopIteration here means this particular file is empty.
            # Move on to the next one to try to get one example.
            pass
    return None
Exemplo n.º 3
0
  def __init__(
      self,
      mode,
      input_file_spec,
      num_examples=None,
      num_classes=dv_constants.NUM_CLASSES,
      max_examples=None,
      tensor_shape=None,
      name=None,
      use_tpu=False,
      input_read_threads=_DEFAULT_INPUT_READ_THREADS,
      input_map_threads=_DEFAULT_INPUT_MAP_THREADS,
      shuffle_buffer_size=_DEFAULT_SHUFFLE_BUFFER_ELEMENTS,
      initial_shuffle_buffer_size=_DEFAULT_INITIAL_SHUFFLE_BUFFER_ELEMENTS,
      prefetch_dataset_buffer_size=_DEFAULT_PREFETCH_BUFFER_BYTES,
      sloppy=True,
      list_files_shuffle=True,
      debugging_true_label_mode=False):
    """Create a DeepVariantInput object, usable as an `input_fn`.

    Args:
      mode: the mode string (from `tf.estimator.ModeKeys`).
      input_file_spec: the input filename for a tfrecord[.gz] file containing
        examples.  Can contain sharding designators.
      num_examples: the number of examples contained in the input file.
        Required for setting learning rate schedule in train/eval only.
      num_classes: The number of classes in the labels of
        this dataset. Currently defaults to DEFAULT_NUM_CLASSES.
      max_examples: The maximum number of examples to use. If None, all examples
        will be used. If not None, the first n = min(max_examples, num_examples)
        will be used. This works with training, and the n examples will repeat
        over and over.
      tensor_shape: None (which means we get the shape from the first example in
        source), or list of int [height, width, channel] for testing.
      name: string, name of the dataset.
      use_tpu: use code paths tuned for TPU, in particular protobuf encoding.
        Default False.
      input_read_threads: number of threads for reading data.  Default 32.
      input_map_threads: number of threads for mapping data.  Default 48.
      shuffle_buffer_size: size of the final shuffle buffer, in elements.
        Default 100.
      initial_shuffle_buffer_size: int; the size of the dataset.shuffle buffer
        in elements.  Default is 1024.
      prefetch_dataset_buffer_size: int; the size of the TFRecordDataset buffer
        in bytes.  Default is 16 * 1000 * 1000.
      sloppy: boolean, allow parallel_interleave to be sloppy.  Default True.
      list_files_shuffle: boolean, allow list_files to shuffle.  Default True.
      debugging_true_label_mode: boolean. If true, the input examples are
                                 created with "training" mode. We'll parse the
                                 'label' field even if the `mode` is PREDICT.
    Raises:
      ValueError: if `num_examples` not provided, in a context requiring it.
    """
    self.mode = mode
    self.input_file_spec = input_file_spec
    self.name = name
    self.num_examples = num_examples
    self.num_classes = num_classes
    self.max_examples = max_examples

    self.use_tpu = use_tpu
    self.sloppy = sloppy
    self.list_files_shuffle = list_files_shuffle
    self.input_read_threads = input_read_threads
    self.input_map_threads = input_map_threads
    self.shuffle_buffer_size = shuffle_buffer_size
    self.initial_shuffle_buffer_size = initial_shuffle_buffer_size
    self.prefetch_dataset_buffer_size = prefetch_dataset_buffer_size
    self.debugging_true_label_mode = debugging_true_label_mode
    # Parse the 'label' field whenever we are in TRAIN/EVAL, or when the
    # caller explicitly asks for it via debugging_true_label_mode.
    self.feature_extraction_spec = self.features_extraction_spec_for_mode(
        mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL) or
        debugging_true_label_mode)

    if num_examples is None and mode != tf.estimator.ModeKeys.PREDICT:
      # Bug fix: add the missing space between the two implicitly
      # concatenated string literals in the error message.
      raise ValueError('num_examples argument required for DeepVariantInput '
                       'in TRAIN/EVAL modes.')

    if max_examples is not None:
      if max_examples <= 0:
        raise ValueError(
            'max_examples must be > 0 if not None. Got {}'.format(max_examples))
      # We update our num_examples in the situation where num_examples is set
      # (i.e., is not None) to the smaller of max_examples and num_examples.
      if self.num_examples is not None:
        self.num_examples = min(max_examples, self.num_examples)

    if tensor_shape:
      self.tensor_shape = tensor_shape
    else:
      # No explicit shape was given; infer it from the first example on disk.
      self.tensor_shape = tf_utils.get_shape_from_examples_path(input_file_spec)
    self.input_files = sharded_file_utils.glob_list_sharded_file_patterns(
        self.input_file_spec)