def test_max_examples(self, max_examples, batch_size):
        """Checks that reading many batches yields exactly max_examples loci.

        Builds a TRAIN-mode input_fn over the golden training examples with
        the given max_examples. It reads a fixed number of batches of size
        batch_size and asserts that the number of distinct 'locus' feature
        values seen equals max_examples.
        """
        golden_input_fn = data_providers.get_input_fn_from_filespec(
            input_file_spec=testdata.GOLDEN_TRAINING_EXAMPLES,
            num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
            name='labeled_golden',
            max_examples=max_examples,
            mode=tf.estimator.ModeKeys.TRAIN)

        num_batches = 100
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())

            dataset_params = dict(batch_size=batch_size)
            one_shot = golden_input_fn(dataset_params).make_one_shot_iterator()
            next_batch = one_shot.get_next()

            # Accumulate loci as each batch is fetched instead of keeping
            # every batch around; the sequence of session.run calls is the
            # same either way.
            seen_loci = set()
            for _ in range(num_batches):
                batch_features, _labels = session.run(next_batch)
                seen_loci.update(batch_features['locus'])

            # assertLen not available OSS.
            # pylint: disable=g-generic-assert
            self.assertEqual(len(seen_loci), max_examples)
# Example No. 2
# 0
def prepare_inputs(source_path, use_tpu=False, num_readers=None):
    """Returns a PREDICT-mode tf.data input_fn reading from source_path.

    Args:
      source_path: Path to a TFRecord file containing deepvariant tf.Example
        protos.
      use_tpu: boolean. Use the tpu code path.
      num_readers: int > 0 or None. Number of parallel readers to use to read
        examples from source_path. If None (or any other falsy value such as
        0), FLAGS.num_readers is used instead.

    Returns:
      A tf input_fn yielding batches of image, encoded_variant,
      encoded_alt_allele_indices.

      The image is a [batch_size, height, width, channel] tensor. The
      encoded_variant is a tf.string or tpu-encoded tensor containing a
      serialized Variant proto describing the variant call associated with
      image. The encoded_alt_allele_indices is a tf.string or tpu-encoded
      tensor containing a serialized CallVariantsOutput.AltAlleleIndices proto
      containing the alternate alleles indices used as "alt" when constructing
      the image.
    """
    # `or` reads FLAGS.num_readers only when num_readers is falsy, exactly
    # like an explicit `if not num_readers:` fallback.
    readers = num_readers or FLAGS.num_readers

    input_fn = data_providers.get_input_fn_from_filespec(
        input_file_spec=source_path,
        mode=tf.estimator.ModeKeys.PREDICT,
        use_tpu=use_tpu,
        input_read_threads=readers,
        debugging_true_label_mode=FLAGS.debugging_true_label_mode,
    )
    return input_fn
# Example No. 3
# 0
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
  """Builds an input_fn over the golden training examples.

  Args:
    compressed_inputs: If True, the golden records are rewritten to the path
      returned by test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
      and the input_fn reads that copy. Otherwise it reads
      testdata.GOLDEN_TRAINING_EXAMPLES directly.
    mode: Forwarded as `mode` to data_providers.get_input_fn_from_filespec.
    use_tpu: Forwarded as `use_tpu` to the same call.

  Returns:
    Whatever data_providers.get_input_fn_from_filespec returns for the chosen
    source, with name='labeled_golden',
    num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES and tensor_shape=None.
  """
  filespec = testdata.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    # Presumably the '.gz' suffix makes write_tfrecords compress the output,
    # which exercises the compressed-input path -- TODO confirm in tfrecord.
    gz_copy = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    tfrecord.write_tfrecords(tfrecord.read_tfrecords(filespec), gz_copy)
    filespec = gz_copy

  return data_providers.get_input_fn_from_filespec(
      input_file_spec=filespec,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
      name='labeled_golden',
      mode=mode,
      tensor_shape=None,
      use_tpu=use_tpu)