def _get_one_example_from_examples_path(source):
  """Reads one record from source."""
  # redacted
  # io_utils.read_tfrecords can read wildcard file patterns.
  # The source can be a comma-separated list.
  source_paths = source.split(',')
  for source_path in source_paths:
    files = tf.gfile.Glob(io_utils.NormalizeToShardedFilePattern(source_path))
    if not files:
      if len(source_paths) > 1:
        raise ValueError(
            'Cannot find matching files with the pattern "{}" in "{}"'.format(
                source_path, ','.join(source_paths)))
      else:
        raise ValueError(
            'Cannot find matching files with the pattern "{}"'.format(
                source_path))
    for f in files:
      try:
        return next(io_utils.read_tfrecords(f))
      except StopIteration:
        # A StopIteration from next() means this file is empty. Move on to
        # the next one to try to get one example.
        pass
  return None
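
# A minimal usage sketch of the helper above (illustrative, not part of the
# original module): peek at one record to recover the image dimensions,
# assuming the records are tf.train.Example protos carrying an 'image/shape'
# int64 feature, which is the DeepVariant convention. The function name is a
# hypothetical stand-in.
def _sketch_peek_image_shape(source):
  example = _get_one_example_from_examples_path(source)
  if example is None:
    return None
  return list(example.features.feature['image/shape'].int64_list.value)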
def test_prepare_inputs(self, filename, expand_to_file_pattern):
  source_path = test_utils.test_tmpfile(filename)
  io_utils.write_tfrecords(self.examples, source_path)
  if expand_to_file_pattern:
    # Transform foo@3 to foo-?????-of-00003.
    source_path = io_utils.NormalizeToShardedFilePattern(source_path)

  with self.test_session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    ds = call_variants.prepare_inputs(source_path)
    _, variants, _ = data_providers.get_infer_batches(
        ds, model=self.model, batch_size=1)

    seen_variants = []
    try:
      while True:
        seen_variants.extend(sess.run(variants))
    except tf.errors.OutOfRangeError:
      pass

    self.assertItemsEqual(self.variants,
                          variant_utils.decode_variants(seen_variants))
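
# The test above takes (filename, expand_to_file_pattern) arguments, so it is
# presumably driven by a parameterized decorator along these lines (an
# assumption; the exact parameter sets in the original file may differ):
#
#   @parameterized.parameters(
#       ('foo.tfrecord', False),
#       ('foo@3.tfrecord', True),
#   )
#   def test_prepare_inputs(self, filename, expand_to_file_pattern):
#     ...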
def prediction_input_fn(self, params):
  """Implementation of `input_fn` contract for prediction mode.

  Args:
    params: a dict containing an integer value for key 'batch_size'.

  Returns:
    a tf.data.Dataset whose elements are a dict of Tensor-valued input
    features; keys populated are:
      'image'
      'variant'
      'alt_allele_indices'
    Aside from 'image', these may be encoded specially for TPU.
  """
  batch_size = params['batch_size']
  compression_type = tf_utils.compression_type_of_files(self.input_files)

  def load_dataset(filename):
    dataset = tf.data.TFRecordDataset(
        filename,
        buffer_size=self.prefetch_dataset_buffer_size,
        compression_type=compression_type)
    return dataset

  files = tf.data.Dataset.list_files(
      io_utils.NormalizeToShardedFilePattern(self.input_file_spec),
      shuffle=False,
  )
  tf.logging.info('self.input_read_threads=%d', self.input_read_threads)
  dataset = files.apply(
      tf.contrib.data.parallel_interleave(
          load_dataset,
          cycle_length=self.input_read_threads,
          sloppy=self.sloppy))
  tf.logging.info('self.input_map_threads=%d', self.input_map_threads)
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
          self.parse_tfexample,
          batch_size=batch_size,
          num_parallel_batches=self.input_map_threads))
  dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
  return dataset
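
# Illustrative sketch (an assumption, not from the original file): an object
# exposing prediction_input_fn above can be handed to TPUEstimator.predict,
# which calls it with params={'batch_size': predict_batch_size}. `model_fn`
# and `dataset_input` are hypothetical stand-ins.
def _sketch_predict(dataset_input, model_fn, predict_batch_size):
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      config=tf.contrib.tpu.RunConfig(),
      use_tpu=False,
      train_batch_size=predict_batch_size,
      predict_batch_size=predict_batch_size)
  return estimator.predict(input_fn=dataset_input.prediction_input_fn)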
def prepare_inputs(source_path, model, batch_size, num_readers=None):
  """Prepares image and encoded_variant ops.

  Reads image / encoded_variant tuples from source_path, decodes each image
  from its png encoding, and preprocesses it with model.preprocess_image.
  Every example in source_path is read once (num_epochs=1).

  Args:
    source_path: Path to a TFRecord file containing deepvariant tf.Example
      protos.
    model: A DeepVariantModel whose preprocess_image function will be used on
      image.
    batch_size: int > 0. Size of batches to use during inference.
    num_readers: int > 0 or None. Number of parallel readers to use to read
      examples from source_path. If None, uses FLAGS.num_readers instead.

  Returns:
    A tuple of (image, encoded_variant, encoded_alt_allele_indices) TF ops.
    image is a [height, width, channel] tensor. encoded_variant is a
    tf.string tensor containing a serialized Variant proto describing the
    variant call associated with image. encoded_alt_allele_indices is a
    tf.string tensor containing a serialized
    CallVariantsOutput.AltAlleleIndices proto containing the alternate allele
    indices used as "alt" when constructing the image.
  """
  if not num_readers:
    num_readers = FLAGS.num_readers

  tensor_shape = tf_utils.get_shape_from_examples_path(source_path)

  def _parse_single_example(serialized_example):
    """Parses serialized example into a dictionary of de-serialized features."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'variant/encoded': tf.FixedLenFeature([], tf.string),
            # deepvariant_pb2.CallVariantsOutput.AltAlleleIndices
            'alt_allele_indices/encoded': tf.FixedLenFeature([], tf.string),
        })
    return features

  with tf.name_scope('input'):

    def _preprocess_image(features):
      """Preprocesses images (decode, reshape, and model-specific steps)."""
      image = features['image/encoded']
      # Bypass the reshaping and preprocessing if there is no tensor_shape.
      # Currently that could happen when the input file is empty.
      if tensor_shape:
        image = tf.reshape(tf.decode_raw(image, tf.uint8), tensor_shape)
        image = model.preprocess_image(image)
      features['image/encoded'] = image
      return features

    files = tf.gfile.Glob(io_utils.NormalizeToShardedFilePattern(source_path))
    reader_options = io_utils.make_tfrecord_options(files)
    if reader_options.compression_type == (
        tf.python_io.TFRecordCompressionType.GZIP):
      compression_type = 'GZIP'
    else:
      compression_type = None
    dataset = tf.data.TFRecordDataset(files, compression_type=compression_type)
    dataset = dataset.map(_parse_single_example, num_parallel_calls=num_readers)
    dataset = dataset.map(_preprocess_image, num_parallel_calls=num_readers)
    dataset = dataset.prefetch(10 * batch_size)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    return (features['image/encoded'], features['variant/encoded'],
            features['alt_allele_indices/encoded'])
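
# A minimal sketch of consuming prepare_inputs() in a TF 1.x session
# (illustrative, mirroring the pattern in the unit test earlier in this
# section): the returned ops come from a one-shot iterator, so we run them
# until the iterator is exhausted and raises OutOfRangeError.
def _sketch_run_inference_inputs(source_path, model, batch_size):
  image, encoded_variant, encoded_alt_allele_indices = prepare_inputs(
      source_path, model, batch_size)
  results = []
  with tf.Session() as sess:
    try:
      while True:
        results.append(
            sess.run([image, encoded_variant, encoded_alt_allele_indices]))
    except tf.errors.OutOfRangeError:
      pass
  return results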
def __call__(self, params):
  """Interface to get a data batch, fulfilling `input_fn` contract.

  Args:
    params: a dict containing an integer value for key 'batch_size'.

  Returns:
    a tf.data.Dataset whose elements are tuples (features, labels), where:
      - features is a dict of Tensor-valued input features; keys populated
        are:
          'image'
          'variant'
          'alt_allele_indices'
        and, if not PREDICT mode, also:
          'locus'
        Aside from 'image', these may be encoded specially for TPU.
      - label is the Tensor-valued prediction label; in train/eval mode the
        label value is populated from the data source; in inference mode,
        the value is a constant empty Tensor value "()".
  """
  # See https://cloud.google.com/tpu/docs/tutorials/inception-v3-advanced
  # for some background on tuning this on TPU.

  batch_size = params['batch_size']
  compression_type = tf_utils.compression_type_of_files(self.input_files)

  # NOTE: The order of the file names returned can be non-deterministic,
  # even if shuffle is false. See b/73959787 and the note in cl/187434282.
  # We need the shuffle flag to be able to disable reordering in EVAL mode.
  dataset = None
  for pattern in self.input_file_spec.split(','):
    one_dataset = tf.data.Dataset.list_files(
        io_utils.NormalizeToShardedFilePattern(pattern),
        shuffle=self.mode == tf.estimator.ModeKeys.TRAIN)
    dataset = dataset.concatenate(one_dataset) if dataset else one_dataset

  # This shuffle applies to the set of files.
  if (self.mode == tf.estimator.ModeKeys.TRAIN and
      self.initial_shuffle_buffer_size > 0):
    dataset = dataset.shuffle(self.initial_shuffle_buffer_size)

  def load_dataset(filename):
    dataset = tf.data.TFRecordDataset(
        filename,
        buffer_size=self.prefetch_dataset_buffer_size,
        compression_type=compression_type)
    return dataset

  if self.mode == tf.estimator.ModeKeys.EVAL:
    # When EVAL, avoid parallel reads for the sake of reproducibility.
    dataset = dataset.interleave(
        load_dataset, cycle_length=self.input_read_threads, block_length=1)
  else:
    dataset = dataset.apply(
        # parallel_interleave requires tf 1.5 or later; this is
        # necessary for good performance.
        tf.contrib.data.parallel_interleave(
            load_dataset,
            cycle_length=self.input_read_threads,
            sloppy=self.sloppy))

  if self.mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.repeat()

  dataset = dataset.map(
      self.parse_tfexample, num_parallel_calls=self.input_map_threads)
  dataset = dataset.prefetch(_PREFETCH_BATCHES * batch_size)

  # This shuffle applies to the set of records.
  if self.mode == tf.estimator.ModeKeys.TRAIN:
    if self.shuffle_buffer_size > 0:
      dataset = dataset.shuffle(self.shuffle_buffer_size)

  if self.mode == tf.estimator.ModeKeys.PREDICT:
    dataset = dataset.batch(batch_size)
  else:
    # N.B.: we drop the final partial batch in eval mode, to work around TPU
    # static shape limitations in the current TPUEstimator.
    dataset = dataset.batch(batch_size, drop_remainder=True)

  return dataset
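
# Illustrative sketch (an assumption, not from the original file): because
# the class implements __call__(params), an instance is itself a valid
# Estimator `input_fn` and can be passed directly to train/evaluate/predict.
# `estimator`, `batch_input`, and `train_steps` are hypothetical stand-ins.
def _sketch_train(estimator, batch_input, train_steps):
  # The Estimator calls batch_input({'batch_size': ...}) to build the
  # tf.data.Dataset defined above.
  estimator.train(input_fn=batch_input, max_steps=train_steps)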
def testNormalizeToShardedFilePattern(self, spec, expected):
  self.assertEqual(expected, io.NormalizeToShardedFilePattern(spec))
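
# Based on the "Transform foo@3 to foo-?????-of-00003" comment earlier in
# this section, NormalizeToShardedFilePattern presumably rewrites an '@N'
# sharded spec into a glob over its shards and leaves other specs unchanged.
# A minimal sketch of that behavior (an assumption, not the library
# implementation; the function name is hypothetical):
import re


def _sketch_normalize_to_sharded_file_pattern(spec):
  """Rewrites 'path@N' as a 'path-?????-of-0000N'-style glob."""
  match = re.match(r'^(.*)@(\d+)$', spec)
  if not match:
    return spec
  basename, num_shards = match.group(1), int(match.group(2))
  return '%s-?????-of-%05d' % (basename, num_shards)

# Example: _sketch_normalize_to_sharded_file_pattern('foo@3') returns
# 'foo-?????-of-00003'.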