def test_prepare_inputs(self, filename_to_write, file_string_input):
    """Checks that prepare_inputs yields every written variant back.

    self.examples is written to `filename_to_write` in the test tmpdir. Then
    `file_string_input` (one path or a comma-separated list of paths, each
    relative to the tmpdir) is fed through prepare_inputs and
    get_infer_batches. The decoded variants must equal self.variants.
    """
    written_path = test_utils.test_tmpfile(filename_to_write)
    io_utils.write_tfrecords(self.examples, written_path)

    # The spec may name several files separated by commas. Each piece needs
    # the tmpdir prefix before the pieces are joined back into one string.
    pieces = file_string_input.split(',')
    prefixed_spec = ','.join(test_utils.test_tmpfile(p) for p in pieces)

    with self.test_session() as session:
        session.run(tf.local_variables_initializer())
        session.run(tf.global_variables_initializer())
        dataset = call_variants.prepare_inputs(prefixed_spec)
        _, variants_op, _ = data_providers.get_infer_batches(
            dataset, model=self.model, batch_size=1)
        collected = []
        try:
            while True:
                collected.extend(session.run(variants_op))
        except tf.errors.OutOfRangeError:
            # Input exhausted: every batch has been consumed.
            pass
        self.assertItemsEqual(self.variants,
                              variant_utils.decode_variants(collected))
def testGetNoneShapeFromEmptyExamplesPath(self, file_name_to_write,
                                          tfrecord_path_to_match):
    """An examples path that contains zero records reports a None shape."""
    io_utils.write_tfrecords([], test_utils.test_tmpfile(file_name_to_write))
    lookup_path = test_utils.test_tmpfile(tfrecord_path_to_match)
    self.assertIsNone(tf_utils.get_shape_from_examples_path(lookup_path))
def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Round-trips self.examples through prepare_inputs/get_infer_batches.

    When `expand_to_file_pattern` is set, a sharded spec such as foo@3 is
    rewritten to the glob foo-?????-of-00003 before it is read back.
    """
    written_path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, written_path)
    if expand_to_file_pattern:
        read_spec = io_utils.NormalizeToShardedFilePattern(written_path)
    else:
        read_spec = written_path

    with self.test_session() as session:
        session.run(tf.local_variables_initializer())
        session.run(tf.global_variables_initializer())
        dataset = call_variants.prepare_inputs(read_spec)
        _, variants_op, _ = data_providers.get_infer_batches(
            dataset, model=self.model, batch_size=1)
        collected = []
        try:
            while True:
                collected.extend(session.run(variants_op))
        except tf.errors.OutOfRangeError:
            # Reached the end of the input.
            pass
        self.assertItemsEqual(self.variants,
                              variant_utils.decode_variants(collected))
def test_call_variants_with_empty_input(self):
    """prepare_inputs must not raise when its input file has no records."""
    empty_path = test_utils.test_tmpfile('empty.tfrecord')
    io_utils.write_tfrecords([], empty_path)
    guess_model = modeling.get_model('random_guess')
    call_variants.prepare_inputs(empty_path, guess_model, batch_size=1)
def testGetShapeFromExamplesPath(self, file_name_to_write,
                                 tfrecord_path_to_match):
    """The shape stored in image/shape is read back from an examples path.

    The original version only called get_shape_from_examples_path and never
    checked its result, so it could not fail on a wrong or missing shape. The
    returned shape is now compared against the one that was written.
    """
    example = example_pb2.Example()
    valid_shape = [1, 2, 3]
    example.features.feature['image/shape'].int64_list.value.extend(valid_shape)
    output_file = test_utils.test_tmpfile(file_name_to_write)
    io_utils.write_tfrecords([example], output_file)
    shape = tf_utils.get_shape_from_examples_path(
        test_utils.test_tmpfile(tfrecord_path_to_match))
    self.assertIsNotNone(shape)
    # list() normalizes a possible protobuf repeated-field container so it
    # compares equal to the plain Python list.
    self.assertEqual(valid_shape, list(shape))
def test_call_end2end_with_empty_shards(self):
    """Spreading at most 10 examples across 15 shards leaves some shards empty.

    call_variants must still emit exactly one record per example.
    """
    num_shards = 15
    # Cap the input at 10 examples so that several shards end up empty.
    few_examples = list(
        io_utils.read_tfrecords(
            testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
    sharded_spec = test_utils.test_tmpfile('sharded@{}'.format(num_shards))
    io_utils.write_tfrecords(few_examples, sharded_spec)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                       len(few_examples))
def test_call_end2end_with_empty_shards(self):
    """call_variants copes with a sharded input where several shards are empty."""
    shard_total = 15
    # At most 10 records spread over 15 shards: some shards must be empty.
    subset = list(
        io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                max_records=10))
    spec = test_utils.test_tmpfile('sharded@{}'.format(shard_total))
    io_utils.write_tfrecords(subset, spec)
    expected_count = len(subset)
    self.assertCallVariantsEmitsNRecordsForRandomGuess(spec, expected_count)
def make_golden_dataset(compressed_inputs=False):
    """Returns a path to the golden postprocess input records.

    Args:
      compressed_inputs: if True, the golden CallVariantsOutput records are
        copied to a tmp file whose name ends in `.tfrecord.gz`, and that path
        is returned. Otherwise testdata.GOLDEN_POSTPROCESS_INPUT is returned
        as is.
    """
    if not compressed_inputs:
        return testdata.GOLDEN_POSTPROCESS_INPUT
    gz_path = test_utils.test_tmpfile(
        'golden.postprocess_single_site_input.tfrecord.gz')
    golden_records = io_utils.read_tfrecords(
        testdata.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    io_utils.write_tfrecords(golden_records, gz_path)
    return gz_path
def make_golden_dataset(compressed_inputs=False):
    """Wraps the golden training examples in a DeepVariantDataSet.

    Args:
      compressed_inputs: if True, the golden examples are first copied to a
        tmp file named `make_golden_dataset.tfrecord.gz` and the dataset reads
        from that copy instead of the original testdata file.
    """
    source = testdata.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        source = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
        io_utils.write_tfrecords(
            io_utils.read_tfrecords(testdata.GOLDEN_TRAINING_EXAMPLES), source)
    return data_providers.DeepVariantDataSet(
        name='labeled_golden',
        source=source,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES)
def test_call_end2end_empty_first_shard(self):
    """A two-shard input whose first shard is empty is still fully called."""
    # Use at most 10 golden examples.
    first_ten = list(
        io_utils.read_tfrecords(
            testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
    shard_layout = [
        ('empty_1st_shard-00000-of-00002', []),
        ('empty_1st_shard-00001-of-00002', first_ten),
    ]
    for shard_name, records in shard_layout:
        io_utils.write_tfrecords(records, test_utils.test_tmpfile(shard_name))
    self.assertCallVariantsEmitsNRecordsForRandomGuess(
        test_utils.test_tmpfile('empty_1st_shard@2'), len(first_ten))
def test_call_end2end_empty_first_shard(self):
    """Calling a 2-shard input whose shard 0 is empty emits every example."""
    subset = list(
        io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                max_records=10))
    # Shard 0 gets nothing; shard 1 holds all of the examples.
    shard0_path = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
    io_utils.write_tfrecords([], shard0_path)
    shard1_path = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
    io_utils.write_tfrecords(subset, shard1_path)
    sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
    self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                       len(subset))
def make_golden_dataset(compressed_inputs=False):
    """Returns the path of the golden postprocess input, optionally recopied.

    When `compressed_inputs` is truthy, the CallVariantsOutput records of
    testdata.GOLDEN_POSTPROCESS_INPUT are rewritten into a tmp file named
    `golden.postprocess_single_site_input.tfrecord.gz`, and that path is
    returned instead of the original.
    """
    if compressed_inputs:
        copy_path = test_utils.test_tmpfile(
            'golden.postprocess_single_site_input.tfrecord.gz')
        io_utils.write_tfrecords(
            io_utils.read_tfrecords(
                testdata.GOLDEN_POSTPROCESS_INPUT,
                proto=deepvariant_pb2.CallVariantsOutput),
            copy_path)
        return copy_path
    return testdata.GOLDEN_POSTPROCESS_INPUT
def test_call_variants_with_no_shape(self, model):
    """prepare_inputs rejects an example that lacks image/shape with ValueError."""
    # Start from the first record of the known-good golden examples.
    shapeless = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    del shapeless.features.feature['image/shape']
    bad_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
    io_utils.write_tfrecords([shapeless], bad_path)
    expected_message = ('Invalid image/shape: we expect to find an image/shape '
                        'field with length 3.')
    with self.assertRaisesRegexp(ValueError, expected_message):
        call_variants.prepare_inputs(bad_path, model, batch_size=1)
def test_get_shape_from_examples_path(self, file_name_to_write,
                                      tfrecord_path_to_match):
    """DeepVariantDataSet.tensor_shape reflects the image/shape on disk."""
    expected_shape = [1, 2, 3]
    proto = example_pb2.Example()
    proto.features.feature['image/shape'].int64_list.value.extend(
        expected_shape)
    io_utils.write_tfrecords([proto],
                             test_utils.test_tmpfile(file_name_to_write))
    dataset = data_providers.DeepVariantDataSet(
        name='test_shape',
        source=test_utils.test_tmpfile(tfrecord_path_to_match),
        num_examples=1)
    self.assertEqual(expected_shape, dataset.tensor_shape)
def test_call_variants_with_no_shape(self, model):
    """Building infer batches from an example lacking image/shape raises."""
    # Take one good record and strip its image/shape feature.
    shapeless = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    del shapeless.features.feature['image/shape']
    bad_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
    io_utils.write_tfrecords([shapeless], bad_path)
    expected_message = ('Invalid image/shape: we expect to find an image/shape '
                        'field with length 3.')
    with self.assertRaisesRegexp(ValueError, expected_message):
        dataset = call_variants.prepare_inputs(bad_path)
        _ = list(
            data_providers.get_infer_batches(dataset, model=model,
                                             batch_size=1))
def test_reading_empty_input_should_raise_error(self):
    """postprocess_variants.main fails when every input shard is empty."""
    shard_paths = [
        test_utils.test_tmpfile(name)
        for name in ('no_records.tfrecord-00000-of-00002',
                     'no_records.tfrecord-00001-of-00002')
    ]
    for shard_path in shard_paths:
        io_utils.write_tfrecords([], shard_path)

    FLAGS.infile = test_utils.test_tmpfile('no_records.tfrecord@2')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('no_records.vcf')
    with self.assertRaisesRegexp(ValueError, 'Cannot find any records in'):
        postprocess_variants.main(['postprocess_variants.py'])
def test_get_shape_from_examples_path(self, file_name_to_write,
                                      tfrecord_path_to_match):
    """A dataset built over a single-example file exposes its image/shape."""
    shape_in = [1, 2, 3]
    record = example_pb2.Example()
    shape_field = record.features.feature['image/shape'].int64_list.value
    shape_field.extend(shape_in)
    target = test_utils.test_tmpfile(file_name_to_write)
    io_utils.write_tfrecords([record], target)
    source_spec = test_utils.test_tmpfile(tfrecord_path_to_match)
    dataset = data_providers.DeepVariantDataSet(
        name='test_shape', source=source_spec, num_examples=1)
    self.assertEqual(shape_in, dataset.tensor_shape)
def make_golden_dataset(compressed_inputs=False):
    """Returns a DeepVariantDataSet named 'labeled_golden' over golden data.

    If `compressed_inputs` is set, the golden training examples are rewritten
    to a tmp `make_golden_dataset.tfrecord.gz` file first and the dataset
    points at that file.
    """
    if not compressed_inputs:
        source_path = testdata.GOLDEN_TRAINING_EXAMPLES
    else:
        source_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
        golden = io_utils.read_tfrecords(testdata.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(golden, source_path)
    dataset_kwargs = dict(
        name='labeled_golden',
        source=source_path,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES)
    return data_providers.DeepVariantDataSet(**dataset_kwargs)
def test_reading_sharded_input_with_empty_shards_does_not_crash(self):
    """postprocess_variants.main succeeds when only some shards are empty."""
    golden_calls = io_utils.read_tfrecords(
        testdata.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    empty_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00000-of-00002')
    populated_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00001-of-00002')
    io_utils.write_tfrecords([], empty_shard)
    io_utils.write_tfrecords(golden_calls, populated_shard)

    FLAGS.infile = test_utils.test_tmpfile('reading_empty_shard.tfrecord@2')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('calls_reading_empty_shard.vcf')
    postprocess_variants.main(['postprocess_variants.py'])
def test_reading_sharded_dataset(self, compressed_inputs):
    """A 3-way sharded copy of the golden data reads back identically."""
    expected = make_golden_dataset(compressed_inputs)
    shard_count = 3
    shard_spec = test_utils.test_tmpfile('sharded@{}'.format(shard_count))
    io_utils.write_tfrecords(
        io_utils.read_tfrecords(expected.source), shard_spec)
    config_path = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=shard_spec,
        num_examples=expected.num_examples)
    actual = data_providers.get_dataset(config_path).get_slim_dataset()
    self.assertDataSetExamplesMatchExpected(actual, expected)
def test_reading_sharded_dataset(self, compressed_inputs):
    """Sharding the golden dataset three ways must not change its examples."""
    golden = make_golden_dataset(compressed_inputs)
    spec = test_utils.test_tmpfile('sharded@{}'.format(3))
    golden_records = io_utils.read_tfrecords(golden.source)
    io_utils.write_tfrecords(golden_records, spec)
    config = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=spec,
        num_examples=golden.num_examples)
    self.assertDataSetExamplesMatchExpected(
        data_providers.get_dataset(config).get_slim_dataset(), golden)
def test_call_variants_with_empty_input(self):
    """An empty input builds and runs without errors other than OutOfRange."""
    empty_path = test_utils.test_tmpfile('empty.tfrecord')
    io_utils.write_tfrecords([], empty_path)
    # prepare_inputs itself must not raise on empty input.
    dataset = call_variants.prepare_inputs(empty_path)
    guess_model = modeling.get_model('random_guess')
    # Per the API, running the batches of an empty input raises
    # OutOfRangeError, which is the expected outcome here.
    batch_tensors = list(
        data_providers.get_infer_batches(dataset, model=guess_model,
                                         batch_size=1))
    with self.test_session() as session:
        session.run(tf.local_variables_initializer())
        session.run(tf.global_variables_initializer())
        try:
            _ = session.run(batch_tensors)
        except tf.errors.OutOfRangeError:
            pass
def test_call_variants_with_invalid_format(self, model, bad_format):
    """call_variants raises ValueError when image/format is not 'raw'."""
    # Take a known-good record and corrupt its image/format value.
    corrupted = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    format_values = corrupted.features.feature['image/format'].bytes_list.value
    format_values[0] = bad_format
    examples_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([corrupted], examples_path)
    calls_path = test_utils.test_tmpfile('call_variants_invalid_format.tfrecord')
    with self.assertRaises(ValueError):
        call_variants.call_variants(
            examples_filename=examples_path,
            checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
            model=model,
            output_file=calls_path,
            batch_size=1,
            max_batches=1)
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
    """Returns an input_fn over the golden training examples.

    Args:
      compressed_inputs: if True, the golden examples are copied to a tmp
        `make_golden_dataset.tfrecord.gz` file and the input_fn reads that
        copy.
      mode: estimator mode forwarded to get_input_fn_from_filespec.
      use_tpu: forwarded to get_input_fn_from_filespec.
    """
    golden_path = testdata.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        spec = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
        io_utils.write_tfrecords(io_utils.read_tfrecords(golden_path), spec)
    else:
        spec = golden_path
    return data_providers.get_input_fn_from_filespec(
        input_file_spec=spec,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
        name='labeled_golden',
        mode=mode,
        tensor_shape=None,
        use_tpu=use_tpu)
def test_call_variants_with_invalid_format(self, model, bad_format):
    """Examples whose image/format is anything but 'raw' are rejected."""
    # Take one good record, then overwrite image/format with a bad value.
    record = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    record.features.feature['image/format'].bytes_list.value[0] = bad_format
    input_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([record], input_path)
    output_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')
    call_kwargs = dict(
        examples_filename=input_path,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=model,
        output_file=output_path,
        batch_size=1,
        max_batches=1,
        use_tpu=FLAGS.use_tpu)
    with self.assertRaises(ValueError):
        call_variants.call_variants(**call_kwargs)
def _call_end2end_helper(self, examples_path, model, shard_inputs):
    """Runs call_variants on `examples_path` and collects its outputs.

    Args:
      examples_path: TFRecord file of input examples.
      model: the model to run; 'random_guess' runs over all batches, any
        other model runs a single batch.
      shard_inputs: if True, the examples are first rewritten as a 3-way
        sharded tmp file and that spec is used as input.

    Returns:
      A tuple (call_variants_outputs, examples, batch_size, max_batches).
    """
    golden = list(io_utils.read_tfrecords(examples_path))
    if not shard_inputs:
        input_spec = examples_path
    else:
        # Create a sharded version of our golden examples.
        input_spec = test_utils.test_tmpfile('sharded@{}'.format(3))
        io_utils.write_tfrecords(golden, input_spec)

    # A headless TPU server is often a 2x2 with 8 replicas, hence 8;
    # otherwise a smaller batch is enough.
    batch = 8 if FLAGS.use_tpu else 4
    # Random guess is cheap enough to run to completion; any other model
    # runs only one inference batch.
    batch_limit = None if model.name == 'random_guess' else 1

    out_path = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=input_spec,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=model,
        output_file=out_path,
        batch_size=batch,
        max_batches=batch_limit,
        master='',
        use_tpu=FLAGS.use_tpu,
    )
    outputs = list(
        io_utils.read_tfrecords(out_path, deepvariant_pb2.CallVariantsOutput))
    return outputs, golden, batch, batch_limit
def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Variants read through prepare_inputs match those written to disk.

    With `expand_to_file_pattern`, a sharded spec such as foo@3 is first
    turned into the glob foo-?????-of-00003.
    """
    path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, path)
    spec = (io_utils.NormalizeToShardedFilePattern(path)
            if expand_to_file_pattern else path)

    with self.test_session() as session:
        # The input ops are built before the variable initializers are run.
        _, variants_op, _ = call_variants.prepare_inputs(
            spec, self.model, batch_size=1)
        session.run(tf.local_variables_initializer())
        session.run(tf.global_variables_initializer())
        collected = []
        try:
            while True:
                collected.extend(session.run(variants_op))
        except tf.errors.OutOfRangeError:
            # All input has been consumed.
            pass
        self.assertItemsEqual(self.variants,
                              variant_utils.decode_variants(collected))
def test_reading_sharded_dataset(self, compressed_inputs, use_tpu):
    """An input_fn over a 3-way sharded copy yields the golden examples."""
    golden = make_golden_dataset(compressed_inputs, use_tpu=use_tpu)
    shard_spec = test_utils.test_tmpfile('sharded@{}'.format(3))
    io_utils.write_tfrecords(
        io_utils.read_tfrecords(golden.input_file_spec), shard_spec)
    config_path = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=shard_spec,
        num_examples=golden.num_examples)
    sharded_input_fn = data_providers.get_input_fn_from_dataset(
        config_path, mode=tf.estimator.ModeKeys.EVAL)
    # workaround_list_files is needed because wildcard expansion (and hence
    # the order of sharded files) is currently nondeterministic.
    self.assertTfDataSetExamplesMatchExpected(
        sharded_input_fn,
        golden,
        workaround_list_files=True,
    )
def test_call_end2end_zero_record_file(self):
    """call_variants on an input file with zero records emits zero records."""
    zero_record_file = test_utils.test_tmpfile('zero_record_file')
    io_utils.write_tfrecords([], zero_record_file)
    # Reuse the path just written instead of rebuilding it from a duplicated
    # literal, so the file written and the file checked cannot drift apart.
    self.assertCallVariantsEmitsNRecordsForRandomGuess(zero_record_file, 0)
def write_test_protos(self, filename):
    """Writes ten ContigInfo protos, named '0' through '9', to a tmp file.

    Args:
      filename: file name, resolved inside the test tmpdir.

    Returns:
      A (protos, path) tuple: the protos that were written and their path.
    """
    contig_names = [str(idx) for idx in range(10)]
    contigs = [reference_pb2.ContigInfo(name=nm) for nm in contig_names]
    out_path = test_utils.test_tmpfile(filename)
    io.write_tfrecords(contigs, out_path)
    return contigs, out_path
def test_call_end2end_zero_record_file_for_inception_v3(self):
    """inception_v3 on an input file with zero records emits zero records."""
    zero_record_file = test_utils.test_tmpfile('zero_record_file')
    io_utils.write_tfrecords([], zero_record_file)
    # Reuse the path just written rather than recomputing it from a repeated
    # literal; otherwise the two could silently diverge.
    self.assertCallVariantsEmitsNRecordsForInceptionV3(zero_record_file, 0)
def test_call_end2end_zero_record_file(self):
    """call_variants on an input file with zero records emits zero records."""
    zero_record_file = test_utils.test_tmpfile('zero_record_file')
    io_utils.write_tfrecords([], zero_record_file)
    # Use the same path object for writing and checking, so the two cannot
    # drift apart if the file name changes.
    self.assertCallVariantsEmitsNRecordsForRandomGuess(zero_record_file, 0)
def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """End-to-end check of call_variants on the golden calling examples.

    Runs call_variants, optionally on a 3-way sharded copy of the input, and
    checks:
      * the number of CallVariantsOutput (CVO) protos produced;
      * debug_info is empty unless include_debug_info is set, and otherwise
        agrees with variant_utils and the genotype probabilities;
      * each CVO matches exactly one input example, with three genotype
        probabilities in [0, 1].
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    if shard_inputs:
        # Create a sharded version of our golden examples.
        source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
        io_utils.write_tfrecords(examples, source_path)
    else:
        source_path = testdata.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
        # For the random guess model we can run everything.
        max_batches = None
    else:
        # For all other models we only run a single batch for inference.
        max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs),
        batch_size * max_batches if max_batches else len(examples))

    # Check that our CVOs have these critical properties:
    # - there is one CVO for each example we processed;
    # - the variant in the CVO is exactly what was in the example;
    # - the alt_allele_indices of the CVO match those of its example;
    # - there are 3 genotype probabilities, each between 0.0 and 1.0.
    # The one-to-one variant check only works when all examples were
    # processed (max_batches is None).
    if max_batches is None:
        self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                              [tf_utils.example_variant(ex) for ex in examples])

    # debug_info is left empty when include_debug_info is False; otherwise
    # it must be filled from the CVO's variant and probabilities.
    if not include_debug_info:
        for cvo in call_variants_outputs:
            self.assertEqual(cvo.debug_info,
                             deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
        for cvo in call_variants_outputs:
            self.assertEqual(cvo.debug_info.has_insertion,
                             variant_utils.has_insertion(cvo.variant))
            self.assertEqual(cvo.debug_info.has_deletion,
                             variant_utils.has_deletion(cvo.variant))
            self.assertEqual(cvo.debug_info.is_snp,
                             variant_utils.is_snp(cvo.variant))
            self.assertEqual(cvo.debug_info.predicted_label,
                             np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
        return (tf_utils.example_variant(example) ==
                call_variants_output.variant and
                tf_utils.example_alt_alleles_indices(example) ==
                call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
        # Find all matching examples.
        matches = [
            ex for ex in examples
            if example_matches_call_variants_output(ex, call_variants_output)
        ]
        # We should have exactly one match.
        self.assertEqual(len(matches), 1)
        example = matches[0]
        # Check that the alt alleles were copied faithfully. As implemented,
        # matching already uses this information, so this cannot currently
        # fail; it is kept in case the matching logic changes.
        self.assertEqual(
            list(tf_utils.example_alt_alleles_indices(example)),
            list(call_variants_output.alt_allele_indices.indices))
        # We should have exactly three genotype probabilities (assuming our
        # ploidy == 2).
        self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
        # These are probabilities, so each must lie in [0, 1]. all() is
        # required: passing a bare generator to assertTrue always succeeds
        # because generator objects are truthy.
        self.assertTrue(
            all(0 <= gp <= 1
                for gp in call_variants_output.genotype_probabilities))
def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """End-to-end check of call_variants on the golden calling examples.

    Runs call_variants, optionally on a 3-way sharded copy of the input, and
    verifies the number of CallVariantsOutput (CVO) protos, the debug_info
    contents, and that each CVO matches exactly one input example with three
    genotype probabilities in [0, 1].
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(
        io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    if shard_inputs:
        # Create a sharded version of our golden examples.
        source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
        io_utils.write_tfrecords(examples, source_path)
    else:
        source_path = testdata.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
        # For the random guess model we can run everything.
        max_batches = None
    else:
        # For all other models we only run a single batch for inference.
        max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs),
        batch_size * max_batches if max_batches else len(examples))

    # Critical CVO properties checked below:
    # - one CVO for each example we processed;
    # - the CVO's variant is exactly the example's variant;
    # - the CVO's alt_allele_indices match its example's;
    # - 3 genotype probabilities, each between 0.0 and 1.0.
    # The one-to-one variant check needs every example to have been
    # processed (max_batches is None).
    if max_batches is None:
        self.assertItemsEqual(
            [cvo.variant for cvo in call_variants_outputs],
            [tf_utils.example_variant(ex) for ex in examples])

    # debug_info is empty if include_debug_info is False; else it is filled
    # by logic based on the CVO.
    if not include_debug_info:
        for cvo in call_variants_outputs:
            self.assertEqual(
                cvo.debug_info,
                deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
        for cvo in call_variants_outputs:
            self.assertEqual(cvo.debug_info.has_insertion,
                             variant_utils.has_insertion(cvo.variant))
            self.assertEqual(cvo.debug_info.has_deletion,
                             variant_utils.has_deletion(cvo.variant))
            self.assertEqual(cvo.debug_info.is_snp,
                             variant_utils.is_snp(cvo.variant))
            self.assertEqual(cvo.debug_info.predicted_label,
                             np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
        return (tf_utils.example_variant(example) ==
                call_variants_output.variant and
                tf_utils.example_alt_alleles_indices(example) ==
                call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
        # Find all matching examples.
        matches = [
            ex for ex in examples
            if example_matches_call_variants_output(ex, call_variants_output)
        ]
        # We should have exactly one match.
        self.assertEqual(len(matches), 1)
        example = matches[0]
        # Alt alleles must be copied faithfully. Matching currently uses this
        # information, so the check cannot fail today; it guards against
        # future changes to the matching logic.
        self.assertEqual(
            list(tf_utils.example_alt_alleles_indices(example)),
            list(call_variants_output.alt_allele_indices.indices))
        # We should have exactly three genotype probabilities (assuming our
        # ploidy == 2).
        self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
        # Each probability must lie in [0, 1]. Wrapping in all() is needed:
        # assertTrue(<generator>) always passes since generators are truthy.
        self.assertTrue(
            all(0 <= gp <= 1
                for gp in call_variants_output.genotype_probabilities))
def call_variants(examples_filename,
                  checkpoint_path,
                  model,
                  output_file,
                  execution_hardware='auto',
                  batch_size=16,
                  max_batches=None):
    """Main driver of call_variants.

    Validates the input examples, builds an inference graph over them,
    restores `model` from `checkpoint_path`, and writes calls batch by batch
    (via call_batch) to `output_file`. It stops when the input is exhausted or
    `max_batches` batches have been processed.

    If no record can be read from `examples_filename`, an empty output file is
    written and the function returns immediately, without building the graph
    or touching the checkpoint.

    Args:
      examples_filename: path/spec of the TFRecord examples to call.
      checkpoint_path: checkpoint handed to model.initialize_from_checkpoint.
      model: model object providing create() and initialize_from_checkpoint().
      output_file: destination path for the output records.
      execution_hardware: one of _ALLOW_EXECUTION_HARDWARE. 'cpu' hides GPU
        and TPU devices from the session; 'accelerator' requires that the
        session lists at least one non-CPU device.
      batch_size: number of examples per inference batch.
      max_batches: if not None, stop after this many batches.

    Raises:
      ValueError: if the examples' image/format is not 'raw', or if
        execution_hardware is not an allowed value.
      ExecutionHardwareError: if 'accelerator' was requested but only CPU
        devices are available.
    """
    # Read a single TFExample to make sure we're not loading an older version.
    example_format = tf_utils.get_format_from_examples_path(examples_filename)
    if example_format is None:
        logging.warning('Unable to read any records from %s. Output will contain '
                        'zero records.', examples_filename)
        io_utils.write_tfrecords([], output_file)
        # Without this return, the code went on to build the graph, restore
        # the checkpoint and re-open output_file for an input known to be
        # empty.
        return
    elif example_format != 'raw':
        raise ValueError('The TF examples in {} has image/format \'{}\' '
                         '(expected \'raw\') which means you might need to rerun '
                         'make_examples to generate the examples again.'.format(
                             examples_filename, example_format))

    if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
        raise ValueError(
            'Unexpected execution_hardware={} value. Allowed values are {}'.format(
                execution_hardware, ','.join(_ALLOW_EXECUTION_HARDWARE)))

    with tf.Graph().as_default():
        images, encoded_variants, encoded_alt_allele_indices = prepare_inputs(
            examples_filename, model, batch_size, FLAGS.num_readers)

        # Create our model and extract the predictions from the model endpoints.
        predictions = model.create(images, 3, is_training=False)['Predictions']

        # The op for initializing the variables.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # 'cpu' hides every GPU/TPU device from the session.
        device_count = {'GPU': 0, 'TPU': 0} if execution_hardware == 'cpu' else {}
        config = tf.ConfigProto(device_count=device_count)
        with tf.Session(config=config) as sess:
            sess.run(init_op)

            # Initialize the model from the provided checkpoint using our session.
            logging.info('Initializing model from %s', checkpoint_path)
            model.initialize_from_checkpoint(checkpoint_path, 3, False)(sess)

            if execution_hardware == 'accelerator':
                if not any(dev.device_type != 'CPU' for dev in sess.list_devices()):
                    raise ExecutionHardwareError(
                        'execution_hardware is set to accelerator, but no accelerator '
                        'was found')

            logging.info('Writing calls to %s', output_file)
            writer, _ = io_utils.make_proto_writer(output_file)
            with writer:
                start_time = time.time()
                try:
                    n_batches = 0
                    n_examples = 0
                    while max_batches is None or n_batches < max_batches:
                        n_called = call_batch(sess, writer, encoded_variants,
                                              encoded_alt_allele_indices,
                                              predictions)
                        duration = time.time() - start_time
                        n_batches += 1
                        n_examples += n_called
                        logging.info(
                            ('Processed %s examples in %s batches [%.2f sec per 100]'),
                            n_examples, n_batches, (100 * duration) / n_examples)
                except tf.errors.OutOfRangeError:
                    # The input pipeline signals exhaustion with OutOfRangeError.
                    logging.info('Done evaluating variants')
def call_variants(examples_filename,
                  checkpoint_path,
                  model,
                  output_file,
                  execution_hardware='auto',
                  batch_size=16,
                  max_batches=None,
                  use_tpu=False,
                  master=''):
    """Main driver of call_variants.

    Validates the input examples and the requested execution hardware. It
    then runs `model`'s estimator over the examples and writes one call per
    prediction (via write_variant_call) to `output_file`.

    If no record can be read from `examples_filename`, an empty output file is
    written and the function returns without running the model.

    When `max_batches` is set, the write loop stops after
    batch_size * max_batches predictions have been written.
    NOTE(review): with max_batches=0 the loop still writes one prediction,
    because n_batches starts at 0 and the condition is `<=`.

    Args:
      examples_filename: path/spec of the TFRecord examples to call.
      checkpoint_path: checkpoint for estimator.predict. None selects the
        no-hooks branch (used by unit tests).
      model: model object providing make_estimator() and
        session_predict_hooks().
      output_file: destination path for the output records.
      execution_hardware: one of _ALLOW_EXECUTION_HARDWARE. 'cpu' hides GPU
        and TPU devices from the probe session; 'accelerator' requires that
        session to list a non-CPU device.
      batch_size: batch size passed to model.make_estimator; also used to
        derive the batch count for max_batches and logging.
      max_batches: if not None, limits output to batch_size * max_batches
        predictions.
      use_tpu: forwarded to prepare_inputs, make_estimator and
        write_variant_call.
      master: forwarded to model.make_estimator.

    Raises:
      ValueError: if the examples' image/format is not 'raw', or if
        execution_hardware is not an allowed value.
      ExecutionHardwareError: if 'accelerator' was requested but the session
        lists only CPU devices.
    """
    if FLAGS.kmp_blocktime:
        # Exported to the environment; assumes the flag value is a string, as
        # os.environ requires.
        os.environ['KMP_BLOCKTIME'] = FLAGS.kmp_blocktime
        logging.info('Set KMP_BLOCKTIME to %s', os.environ['KMP_BLOCKTIME'])

    # Read a single TFExample to make sure we're not loading an older version.
    example_format = tf_utils.get_format_from_examples_path(examples_filename)
    if example_format is None:
        logging.warning(
            'Unable to read any records from %s. Output will contain '
            'zero records.', examples_filename)
        io_utils.write_tfrecords([], output_file)
        return
    elif example_format != 'raw':
        raise ValueError(
            'The TF examples in {} has image/format \'{}\' '
            '(expected \'raw\') which means you might need to rerun '
            'make_examples to generate the examples again.'.format(
                examples_filename, example_format))

    # Check accelerator status.
    if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
        raise ValueError(
            'Unexpected execution_hardware={} value. Allowed values are {}'.
            format(execution_hardware, ','.join(_ALLOW_EXECUTION_HARDWARE)))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    # 'cpu' hides every GPU/TPU device from this probe session.
    device_count = {'GPU': 0, 'TPU': 0} if execution_hardware == 'cpu' else {}
    config = tf.ConfigProto(device_count=device_count)
    # This session is only used to run the initializers and to probe the
    # visible devices. It is closed before prediction starts; prediction is
    # driven by estimator.predict below.
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if execution_hardware == 'accelerator':
            if not any(dev.device_type != 'CPU' for dev in sess.list_devices()):
                raise ExecutionHardwareError(
                    'execution_hardware is set to accelerator, but no accelerator '
                    'was found')
        # redacted
        # sess.list_devices here doesn't return the correct answer. That can only
        # work later, after the device (on the other VM) has been initialized,
        # which is generally not yet.

    # Prepare input stream and estimator.
    tf_dataset = prepare_inputs(source_path=examples_filename, use_tpu=use_tpu)
    estimator = model.make_estimator(
        batch_size=batch_size,
        master=master,
        use_tpu=use_tpu,
    )
    # Instantiate the prediction "stream", and select the EMA values from
    # the model.
    if checkpoint_path is None:
        # Unit tests use this branch.
        predict_hooks = []
    else:
        predict_hooks = [
            h(checkpoint_path) for h in model.session_predict_hooks()
        ]
    predictions = iter(
        estimator.predict(
            input_fn=tf_dataset,
            checkpoint_path=checkpoint_path,
            hooks=predict_hooks))

    # Consume predictions one at a time and write them to output_file.
    logging.info('Writing calls to %s', output_file)
    writer, _ = io_utils.make_proto_writer(output_file)
    with writer:
        start_time = time.time()
        n_examples, n_batches = 0, 0
        while max_batches is None or n_batches <= max_batches:
            try:
                prediction = next(predictions)
            except (StopIteration, tf.errors.OutOfRangeError):
                # Either the prediction iterator or the input pipeline is
                # exhausted.
                break
            write_variant_call(writer, prediction, use_tpu)
            n_examples += 1
            # 1-based index of the batch the next example would fall in.
            # NOTE(review): once n_examples is an exact multiple of
            # batch_size, the logged batch count is one higher than the
            # number of batches actually completed.
            n_batches = n_examples // batch_size + 1
            duration = time.time() - start_time

            logging.log_every_n(
                logging.INFO,
                ('Processed %s examples in %s batches [%.3f sec per 100]'),
                _LOG_EVERY_N, n_examples, n_batches, (100 * duration) / n_examples)
        logging.info('Done evaluating variants')
def call_variants(examples_filename,
                  checkpoint_path,
                  model,
                  output_file,
                  execution_hardware='auto',
                  batch_size=16,
                  max_batches=None):
    """Main driver of call_variants.

    Validates the input examples, builds an inference graph over them,
    restores `model` from `checkpoint_path`, and writes calls batch by batch
    (via call_batch) to `output_file`. It stops when the input is exhausted or
    `max_batches` batches have been processed.

    If no record can be read from `examples_filename`, an empty output file is
    written and the function returns.

    Args:
      examples_filename: path/spec of the TFRecord examples to call.
      checkpoint_path: checkpoint handed to model.initialize_from_checkpoint.
      model: model object providing create() and initialize_from_checkpoint().
      output_file: destination path for the output records.
      execution_hardware: one of _ALLOW_EXECUTION_HARDWARE. 'cpu' hides GPU
        and TPU devices from the session; 'accelerator' requires a non-CPU
        device to be listed.
      batch_size: number of examples per inference batch.
      max_batches: if not None, stop after this many batches.

    Raises:
      ValueError: if the examples' image/format is not 'raw', or if
        execution_hardware is not an allowed value.
      ExecutionHardwareError: if 'accelerator' was requested but only CPU
        devices are available.
    """
    # Read a single TFExample to make sure we're not loading an older version.
    example_format = tf_utils.get_format_from_examples_path(examples_filename)
    if example_format is None:
        logging.warning('Unable to read any records from %s. Output will contain '
                        'zero records.', examples_filename)
        io_utils.write_tfrecords([], output_file)
        return
    elif example_format != 'raw':
        raise ValueError('The TF examples in {} has image/format \'{}\' '
                         '(expected \'raw\') which means you might need to rerun '
                         'make_examples to generate the examples again.'.format(
                             examples_filename, example_format))

    if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
        raise ValueError(
            'Unexpected execution_hardware={} value. Allowed values are {}'.format(
                execution_hardware, ','.join(_ALLOW_EXECUTION_HARDWARE)))

    with tf.Graph().as_default():
        images, encoded_variants, encoded_alt_allele_indices = prepare_inputs(
            examples_filename, model, batch_size, FLAGS.num_readers)

        # Create our model and extract the predictions from the model endpoints.
        # NOTE(review): the model is built with the literal class count 2 here
        # (and in initialize_from_checkpoint below); confirm this matches the
        # checkpoint being restored.
        predictions = model.create(images, 2, is_training=False)['Predictions']

        # The op for initializing the variables.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # 'cpu' hides every GPU/TPU device from the session.
        device_count = {'GPU': 0, 'TPU': 0} if execution_hardware == 'cpu' else {}
        config = tf.ConfigProto(device_count=device_count)
        with tf.Session(config=config) as sess:
            sess.run(init_op)

            # Initialize the model from the provided checkpoint using our session.
            logging.info('Initializing model from %s', checkpoint_path)
            model.initialize_from_checkpoint(checkpoint_path, 2, False)(sess)

            if execution_hardware == 'accelerator':
                if not any(dev.device_type != 'CPU' for dev in sess.list_devices()):
                    raise ExecutionHardwareError(
                        'execution_hardware is set to accelerator, but no accelerator '
                        'was found')

            logging.info('Writing calls to %s', output_file)
            writer, _ = io_utils.make_proto_writer(output_file)
            with writer:
                start_time = time.time()
                try:
                    n_batches = 0
                    n_examples = 0
                    while max_batches is None or n_batches < max_batches:
                        # NOTE(review): the rate computation below divides by
                        # n_examples; this assumes call_batch never returns 0 on
                        # the first batch (it presumably raises OutOfRangeError
                        # instead) — confirm.
                        n_called = call_batch(sess, writer, encoded_variants,
                                              encoded_alt_allele_indices,
                                              predictions)
                        duration = time.time() - start_time
                        n_batches += 1
                        n_examples += n_called
                        logging.info(
                            ('Processed %s examples in %s batches [%.2f sec per 100]'),
                            n_examples, n_batches, (100 * duration) / n_examples)
                except tf.errors.OutOfRangeError:
                    # Raised once the input pipeline is exhausted.
                    logging.info('Done evaluating variants')
def test_call_variants_with_empty_input(self):
    """Building the input pipeline over an empty file must not raise."""
    path_of_empty = test_utils.test_tmpfile('empty.tfrecord')
    io_utils.write_tfrecords([], path_of_empty)
    baseline_model = modeling.get_model('random_guess')
    call_variants.prepare_inputs(
        path_of_empty, baseline_model, batch_size=1)