Example #1
def _get_one_example_from_examples_path(source):
    """Reads one record from source."""
    # redacted
    # io_utils.read_tfrecord can read wildcard file patterns.
    # The source can be a comma-separated list.
    source_paths = source.split(',')
    for source_path in source_paths:
        files = tf.gfile.Glob(
            io_utils.NormalizeToShardedFilePattern(source_path))
        if not files:
            if len(source_paths) > 1:
                raise ValueError(
                    'Cannot find matching files with the pattern "{}" in "{}"'.
                    format(source_path, ','.join(source_paths)))
            else:
                raise ValueError(
                    'Cannot find matching files with the pattern "{}"'.format(
                        source_path))
        for f in files:
            try:
                return next(io_utils.read_tfrecords(f))
            except StopIteration:
                # Getting a StopIteration from one next() means source_path is empty.
                # Move on to the next one to try to get one example.
                pass
    return None
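
A minimal usage sketch for the helper above, assuming the surrounding module's imports (tensorflow as tf, io_utils) are in scope; the sharded path is hypothetical:

# Hypothetical usage: peek at a single record to inspect the data layout.
# 'examples.tfrecord@2' is an assumed sharded spec; NormalizeToShardedFilePattern
# expands it to 'examples.tfrecord-?????-of-00002' before globbing.
one_example = _get_one_example_from_examples_path('/tmp/examples.tfrecord@2')
if one_example is None:
    tf.logging.info('All matched files were empty.')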
Example #2
    def test_prepare_inputs(self, filename, expand_to_file_pattern):
        source_path = test_utils.test_tmpfile(filename)
        io_utils.write_tfrecords(self.examples, source_path)
        if expand_to_file_pattern:
            # Transform foo@3 to foo-?????-of-00003.
            source_path = io_utils.NormalizeToShardedFilePattern(source_path)

        with self.test_session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            ds = call_variants.prepare_inputs(source_path)
            _, variants, _ = data_providers.get_infer_batches(ds,
                                                              model=self.model,
                                                              batch_size=1)

            seen_variants = []
            try:
                while True:
                    seen_variants.extend(sess.run(variants))
            except tf.errors.OutOfRangeError:
                pass

            self.assertItemsEqual(self.variants,
                                  variant_utils.decode_variants(seen_variants))
Example #3
  def prediction_input_fn(self, params):
    """Implementation of `input_fn` contract for prediction mode.

    Args:
      params: a dict containing an integer value for key 'batch_size'.

    Returns:
      A tf.data.Dataset whose elements are (features, labels) tuples, where:
        - features is a dict of Tensor-valued input features; keys populated
          are:
            'image'
            'variant'
            'alt_allele_indices'

          Aside from 'image', these may be encoded specially for TPU.
    """

    def load_dataset(filename):
      dataset = tf.data.TFRecordDataset(
          filename,
          buffer_size=self.prefetch_dataset_buffer_size,
          compression_type=compression_type)
      return dataset

    batch_size = params['batch_size']
    compression_type = tf_utils.compression_type_of_files(self.input_files)
    files = tf.data.Dataset.list_files(
        io_utils.NormalizeToShardedFilePattern(self.input_file_spec),
        shuffle=False,
    )
    tf.logging.info('self.input_read_threads=%d', self.input_read_threads)
    dataset = files.apply(
        tf.contrib.data.parallel_interleave(
            load_dataset,
            cycle_length=self.input_read_threads,
            sloppy=self.sloppy))
    tf.logging.info('self.input_map_threads=%d', self.input_map_threads)
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            self.parse_tfexample,
            batch_size=batch_size,
            num_parallel_batches=self.input_map_threads))
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset
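
A hedged sketch of driving this input_fn by hand; the `provider` instance and batch size below are assumptions, and the params dict follows the documented 'batch_size' contract:

# Sketch only: `provider` is an assumed, already-constructed instance of the
# class this method belongs to; TPUEstimator would normally supply `params`.
params = {'batch_size': 32}
dataset = provider.prediction_input_fn(params)
# TF 1.x: pull a single batch out of the tf.data pipeline to inspect it.
next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  one_batch = sess.run(next_element)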
Example #4
def prepare_inputs(source_path, model, batch_size, num_readers=None):
  """Prepares image and encoded_variant ops.

  Reads image / encoded_variant tuples from source_path, extracting the image
  and encoded_variant tensors from each example. The image is decoded from its
  raw byte encoding and preprocessed with model.preprocess_image. Every
  example in source_path is read exactly once (a single epoch).

  Args:
    source_path: Path to a TFRecord file containing deepvariant tf.Example
      protos.
    model: A DeepVariantModel whose preprocess_image function will be used on
      image.
    batch_size: int > 0. Size of batches to use during inference.
    num_readers: int > 0 or None. Number of parallel readers to use to read
      examples from source_path. If None, uses FLAGS.num_readers instead.

  Returns:
    A tuple of (image, encoded_variant, encoded_alt_allele_indices) TF ops.
    image is a [batch_size, height, width, channel] tensor.
    encoded_variant is a tf.string tensor containing a serialized Variant proto
    describing the variant call associated with image.
    encoded_alt_allele_indices is a tf.string tensor containing a serialized
    CallVariantsOutput.AltAlleleIndices proto with the alternate allele
    indices used as "alt" when constructing the image.
  """
  if not num_readers:
    num_readers = FLAGS.num_readers

  tensor_shape = tf_utils.get_shape_from_examples_path(source_path)

  def _parse_single_example(serialized_example):
    """Parses serialized example into a dictionary of de-serialized features."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'variant/encoded': tf.FixedLenFeature([], tf.string),
            # deepvariant_pb2.CallVariantsOutput.AltAlleleIndices
            'alt_allele_indices/encoded': tf.FixedLenFeature([], tf.string),
        })
    return features

  with tf.name_scope('input'):

    def _preprocess_image(features):
      """Preprocess images (decode, reshape, and apply model-specific steps)."""
      image = features['image/encoded']
      # Bypassing the reshaping and preprocessing if there is no tensor_shape.
      # Currently that could happen when the input file is empty.
      if tensor_shape:
        image = tf.reshape(tf.decode_raw(image, tf.uint8), tensor_shape)
        image = model.preprocess_image(image)
      features['image/encoded'] = image
      return features

    files = tf.gfile.Glob(io_utils.NormalizeToShardedFilePattern(source_path))
    reader_options = io_utils.make_tfrecord_options(files)
    if reader_options.compression_type == (
        tf.python_io.TFRecordCompressionType.GZIP):
      compression_type = 'GZIP'
    else:
      compression_type = None
    dataset = tf.data.TFRecordDataset(files, compression_type=compression_type)
    dataset = dataset.map(
        _parse_single_example, num_parallel_calls=num_readers)
    dataset = dataset.map(
        _preprocess_image, num_parallel_calls=num_readers)
    dataset = dataset.prefetch(10 * batch_size)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    return (features['image/encoded'], features['variant/encoded'],
            features['alt_allele_indices/encoded'])
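
A hedged usage sketch for the ops returned above, modeled on the test in Example #2; `my_model` and the TFRecord path are placeholders:

# Sketch: drain the single-epoch pipeline, collecting serialized Variant protos.
# `my_model` and the source path are assumptions for illustration only.
_, encoded_variant, _ = prepare_inputs(
    '/tmp/examples.tfrecord', model=my_model, batch_size=16)
seen = []
with tf.Session() as sess:
  try:
    while True:
      seen.extend(sess.run(encoded_variant))
  except tf.errors.OutOfRangeError:
    pass  # One pass over source_path is complete.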
Example #5
    def __call__(self, params):
        """Interface to get a data batch, fulfilling `input_fn` contract.

        Args:
          params: a dict containing an integer value for key 'batch_size'.

        Returns:
          A tf.data.Dataset whose elements are (features, labels) tuples, where:
            - features is a dict of Tensor-valued input features; keys populated
              are:
                'image'
                'variant'
                'alt_allele_indices'
              and, if not PREDICT mode, also:
                'locus'

              Aside from 'image', these may be encoded specially for TPU.

            - labels is the Tensor-valued prediction label; in train/eval
              mode the label value is populated from the data source; in
              inference mode, the value is a constant empty Tensor value "()".
        """
        # See https://cloud.google.com/tpu/docs/tutorials/inception-v3-advanced
        # for some background on tuning this on TPU.

        batch_size = params['batch_size']

        compression_type = tf_utils.compression_type_of_files(self.input_files)

        # NOTE: The order of the file names returned can be non-deterministic,
        # even if shuffle is false.  See b/73959787 and the note in cl/187434282.
        # We need the shuffle flag to be able to disable reordering in EVAL mode.
        dataset = None
        for pattern in self.input_file_spec.split(','):
            one_dataset = tf.data.Dataset.list_files(
                io_utils.NormalizeToShardedFilePattern(pattern),
                shuffle=self.mode == tf.estimator.ModeKeys.TRAIN)
            dataset = dataset.concatenate(
                one_dataset) if dataset else one_dataset

        # This shuffle applies to the set of files.
        if (self.mode == tf.estimator.ModeKeys.TRAIN
                and self.initial_shuffle_buffer_size > 0):
            dataset = dataset.shuffle(self.initial_shuffle_buffer_size)

        def load_dataset(filename):
            dataset = tf.data.TFRecordDataset(
                filename,
                buffer_size=self.prefetch_dataset_buffer_size,
                compression_type=compression_type)
            return dataset

        if self.mode == tf.estimator.ModeKeys.EVAL:
            # When EVAL, avoid parallel reads for the sake of reproducibility.
            dataset = dataset.interleave(load_dataset,
                                         cycle_length=self.input_read_threads,
                                         block_length=1)
        else:
            dataset = dataset.apply(
                # parallel_interleave requires tf 1.5 or later; this is
                # necessary for good performance.
                tf.contrib.data.parallel_interleave(
                    load_dataset,
                    cycle_length=self.input_read_threads,
                    sloppy=self.sloppy))

        if self.mode == tf.estimator.ModeKeys.TRAIN:
            dataset = dataset.repeat()

        dataset = dataset.map(self.parse_tfexample,
                              num_parallel_calls=self.input_map_threads)

        dataset = dataset.prefetch(_PREFETCH_BATCHES * batch_size)

        # This shuffle applies to the set of records.
        if self.mode == tf.estimator.ModeKeys.TRAIN:
            if self.shuffle_buffer_size > 0:
                dataset = dataset.shuffle(self.shuffle_buffer_size)

        if self.mode == tf.estimator.ModeKeys.PREDICT:
            dataset = dataset.batch(batch_size)
        else:
            # N.B.: we drop the final partial batch in eval mode, to work around TPU
            # static shape limitations in the current TPUEstimator.
            dataset = dataset.batch(batch_size, drop_remainder=True)

        return dataset
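
Because __call__ takes a params dict, an instance of this class can be handed directly to an estimator as its input_fn. A heavily hedged sketch follows; model_fn, run_config, and the provider instance are all assumptions:

# Sketch: TPUEstimator injects params['batch_size'] before invoking the
# callable. Everything named here (model_fn, run_config, provider) is assumed.
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=False,
    train_batch_size=64,
    predict_batch_size=64)
# `provider` is assumed to have been constructed for PREDICT mode.
for prediction in estimator.predict(input_fn=provider):
    pass  # Consume predictions one at a time.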
Example #6
  def testNormalizeToShardedFilePattern(self, spec, expected):
    self.assertEquals(expected, io.NormalizeToShardedFilePattern(spec))