def test_max_examples(self, max_examples, batch_size):
  """Checks that max_examples caps the number of distinct loci yielded.

  Builds a TRAIN-mode input_fn over the golden labeled examples with
  max_examples set, reads many batches, and verifies that exactly
  max_examples unique 'locus' values ever appear.
  """
  input_fn = data_providers.get_input_fn_from_filespec(
      input_file_spec=testdata.GOLDEN_TRAINING_EXAMPLES,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
      name='labeled_golden',
      max_examples=max_examples,
      mode=tf.estimator.ModeKeys.TRAIN)
  num_batches = 100
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    next_element = input_fn(
        dict(batch_size=batch_size)).make_one_shot_iterator().get_next()
    # Accumulate every locus seen across all batches; a capped dataset
    # repeats the same max_examples loci no matter how long we read.
    seen_loci = set()
    for _ in range(num_batches):
      features, _ = sess.run(next_element)
      seen_loci.update(features['locus'])
    # assertLen not available OSS.
    # pylint: disable=g-generic-assert
    self.assertEqual(len(seen_loci), max_examples)
def prepare_inputs(source_path, use_tpu=False, num_readers=None):
  """Return a tf.data input_fn from the source_path.

  Args:
    source_path: Path to a TFRecord file containing deepvariant tf.Example
      protos.
    use_tpu: boolean. Use the tpu code path.
    num_readers: int > 0 or None. Number of parallel readers to use to read
      examples from source_path. If None, uses FLAGS.num_readers instead.

  Returns:
    A tf input_fn yielding batches of image, encoded_variant,
    encoded_alt_allele_indices.

    The image is a [batch_size, height, width, channel] tensor. The
    encoded_variants is a tf.string or tpu-encoded tensor containing a
    serialized Variant proto describing the variant call associated with
    image. The encoded_alt_allele_indices is a tf.string or tpu-encoded
    tensor containing a serialized CallVariantsOutput.AltAlleleIndices proto
    containing the alternate alleles indices used as "alt" when constructing
    the image.
  """
  # Explicit None check: the documented contract is "int > 0 or None", and
  # `not num_readers` would also (silently) treat an invalid 0 as "use the
  # flag default".
  if num_readers is None:
    num_readers = FLAGS.num_readers
  return data_providers.get_input_fn_from_filespec(
      input_file_spec=source_path,
      mode=tf.estimator.ModeKeys.PREDICT,
      use_tpu=use_tpu,
      input_read_threads=num_readers,
      debugging_true_label_mode=FLAGS.debugging_true_label_mode,
  )
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
  """Returns an input_fn over the golden labeled training examples.

  Args:
    compressed_inputs: bool. If True, first rewrites the golden examples
      into a gzipped TFRecord in a temp file and reads from that instead.
    mode: a tf.estimator.ModeKeys value to build the input_fn for.
    use_tpu: boolean. Use the tpu code path.
  """
  # Default to reading the golden examples directly; override with a
  # freshly written compressed copy when requested.
  source_path = testdata.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    source_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    tfrecord.write_tfrecords(
        tfrecord.read_tfrecords(testdata.GOLDEN_TRAINING_EXAMPLES),
        source_path)
  return data_providers.get_input_fn_from_filespec(
      input_file_spec=source_path,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
      name='labeled_golden',
      mode=mode,
      tensor_shape=None,
      use_tpu=use_tpu)