Code example #1
  def prediction_input_fn(self, params):
    """Implementation of `input_fn` contract for prediction mode.

    Args:
      params: a dict containing an integer value for key 'batch_size'.

    Returns:
      a tf.data.Dataset whose elements are the tuple (features, labels),
      where:
        - features is a dict of Tensor-valued input features; keys populated
          are:
            'image'
            'variant'
            'alt_allele_indices'

          Aside from 'image', these may be encoded specially for TPU.
    """

    def load_dataset(filename):
      # `compression_type` is captured by closure; it is assigned below,
      # before load_dataset is first invoked by the interleave.
      dataset = tf.data.TFRecordDataset(
          filename,
          buffer_size=self.prefetch_dataset_buffer_size,
          compression_type=compression_type)
      return dataset

    batch_size = params['batch_size']
    compression_type = tf_utils.compression_type_of_files(self.input_files)
    files = tf.data.Dataset.list_files(
        sharded_file_utils.normalize_to_sharded_file_pattern(
            self.input_file_spec),
        shuffle=False,
    )
    logging.vlog(
        3, 'self.input_read_threads={}'.format(self.input_read_threads))
    dataset = files.apply(
        tf.data.experimental.parallel_interleave(
            load_dataset,
            cycle_length=self.input_read_threads,
            sloppy=self.sloppy))
    logging.vlog(
        3, 'self.input_map_threads={}'.format(self.input_map_threads))
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(
            self.parse_tfexample,
            batch_size=batch_size,
            num_parallel_batches=self.input_map_threads))
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
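
For context, here is a minimal sketch of how an Estimator-style caller might
exercise this input_fn. The class name `DeepVariantInput` and its constructor
arguments are assumptions; only the method itself appears above, and in
practice TPUEstimator supplies params['batch_size'] itself.

  # Hypothetical harness (class name and constructor are assumptions).
  provider = DeepVariantInput(
      mode=tf.estimator.ModeKeys.PREDICT,
      input_file_spec='/tmp/examples@4.tfrecord.gz')
  dataset = provider.prediction_input_fn({'batch_size': 32})
  batch = dataset.make_one_shot_iterator().get_next()  # TF 1.x idiom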
Code example #2
  def testCompressionTypeOfFiles(self):
    self.assertEqual(
        'GZIP',
        tf_utils.compression_type_of_files(['/tmp/foo.tfrecord.gz']))
    self.assertIsNone(
        tf_utils.compression_type_of_files(['/tmp/foo.tfrecord']))
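
This test pins down the contract: files ending in '.gz' are read as GZIP,
anything else as uncompressed. A minimal implementation consistent with the
test might look like the sketch below; the real
tf_utils.compression_type_of_files may handle more cases (mixed or empty
file lists, other extensions) than this does.

  def compression_type_of_files(files):
    """Returns 'GZIP' if every file ends in '.gz', else None (sketch)."""
    return 'GZIP' if all(f.endswith('.gz') for f in files) else None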
Code example #3
  def __call__(self, params):
    """Interface to get a data batch, fulfilling `input_fn` contract.

    Args:
      params: a dict containing an integer value for key 'batch_size'.

    Returns:
      a tf.data.Dataset whose elements are the tuple (features, labels),
      where:
        - features is a dict of Tensor-valued input features; keys populated
          are:
            'image'
            'variant'
            'alt_allele_indices'
          and, if not PREDICT mode, also:
            'locus'

          Aside from 'image', these may be encoded specially for TPU.

        - labels is the Tensor-valued prediction label; in train/eval
          mode the label value is populated from the data source; in
          inference mode, the value is a constant empty Tensor value "()".
    """
    # See https://cloud.google.com/tpu/docs/tutorials/inception-v3-advanced
    # for some background on tuning this on TPU.

    # TPU optimized implementation for prediction mode
    if self.mode == tf.estimator.ModeKeys.PREDICT:
      return self.prediction_input_fn(params)

    # Optimized following:
    #   https://www.tensorflow.org/guide/performance/datasets
    # using the information available from xprof.
    def load_dataset(filename):
      dataset = tf.data.TFRecordDataset(
          filename,
          buffer_size=self.prefetch_dataset_buffer_size,
          compression_type=compression_type)
      return dataset

    batch_size = params['batch_size']
    compression_type = tf_utils.compression_type_of_files(self.input_files)

    # NOTE: The order of the file names returned can be non-deterministic,
    # even if shuffle is false.  See b/73959787 and the note in cl/187434282.
    # We need the shuffle flag to be able to disable reordering in EVAL mode.
    dataset = None
    for pattern in self.input_file_spec.split(','):
      one_dataset = tf.data.Dataset.list_files(
          sharded_file_utils.normalize_to_sharded_file_pattern(pattern),
          shuffle=self.mode == tf.estimator.ModeKeys.TRAIN)
      dataset = (
          dataset.concatenate(one_dataset) if dataset is not None
          else one_dataset)

    # This shuffle applies to the set of files.
    # redacted
    if (self.mode == tf.estimator.ModeKeys.TRAIN and
        self.initial_shuffle_buffer_size > 0):
      dataset = dataset.shuffle(self.initial_shuffle_buffer_size)

    if self.mode == tf.estimator.ModeKeys.EVAL:
      # When EVAL, avoid parallel reads for the sake of reproducibility.
      dataset = dataset.interleave(
          load_dataset, cycle_length=self.input_read_threads, block_length=1)
    else:
      dataset = dataset.apply(
          # parallel_interleave is necessary for good read performance.
          tf.data.experimental.parallel_interleave(
              load_dataset,
              cycle_length=self.input_read_threads,
              sloppy=self.sloppy))

    if self.max_examples is not None:
      dataset = dataset.take(self.max_examples)

    if self.mode == tf.estimator.ModeKeys.TRAIN:
      dataset = dataset.repeat()
      # This shuffle applies to the set of records.
      if self.shuffle_buffer_size > 0:
        dataset = dataset.shuffle(self.shuffle_buffer_size)

    # map_and_batch fuses the parse step with batching. drop_remainder=True
    # keeps every batch the same size, which TPU execution requires.
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(
            map_func=self.parse_tfexample,
            batch_size=batch_size,
            num_parallel_batches=_PREFETCH_BATCHES,
            drop_remainder=True))

    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
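
As with the prediction path, a short harness shows how this `input_fn` is
driven; again, `DeepVariantInput` and its constructor arguments are
assumptions, and TPUEstimator would normally pass params['batch_size']:

  input_fn = DeepVariantInput(
      mode=tf.estimator.ModeKeys.TRAIN,
      input_file_spec='/tmp/train@16.tfrecord.gz')
  dataset = input_fn({'batch_size': 64})
  # Each element is a (features, labels) batch per the docstring above.
  features, labels = dataset.make_one_shot_iterator().get_next()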