def testGetNoneShapeFromEmptyExamplesPath(self, file_name_to_write,
                                          tfrecord_path_to_match):
  """An examples path that only matches empty TFRecords yields no shape."""
  written_path = test_utils.test_tmpfile(file_name_to_write)
  io_utils.write_tfrecords([], written_path)
  lookup_path = test_utils.test_tmpfile(tfrecord_path_to_match)
  shape = tf_utils.get_shape_from_examples_path(lookup_path)
  self.assertIsNone(shape)
def testGetShapeFromExamplesPath(self, file_name_to_write,
                                 tfrecord_path_to_match):
  """get_shape_from_examples_path returns the shape stored in the examples.

  Writes one Example whose 'image/shape' feature is [1, 2, 3] to
  file_name_to_write, then reads the shape back through
  tfrecord_path_to_match and checks that it round-trips.
  """
  example = example_pb2.Example()
  valid_shape = [1, 2, 3]
  example.features.feature['image/shape'].int64_list.value.extend(
      valid_shape)
  output_file = test_utils.test_tmpfile(file_name_to_write)
  io_utils.write_tfrecords([example], output_file)
  shape = tf_utils.get_shape_from_examples_path(
      test_utils.test_tmpfile(tfrecord_path_to_match))
  # The result used to be discarded, so this test could never fail on a wrong
  # shape. Compare as a list because the exact sequence type returned is not
  # visible from here.
  self.assertIsNotNone(shape)
  self.assertEqual(valid_shape, list(shape))
def test_call_end2end_empty_first_shard(self):
  """call_variants handles a 2-shard input whose first shard is empty."""
  # Keep the run small: at most 10 golden examples.
  golden = list(
      io_utils.read_tfrecords(
          test_utils.GOLDEN_CALLING_EXAMPLES, max_records=10))
  shard_contents = (
      ('empty_1st_shard-00000-of-00002', []),
      ('empty_1st_shard-00001-of-00002', golden),
  )
  for shard_name, records in shard_contents:
    io_utils.write_tfrecords(records, test_utils.test_tmpfile(shard_name))
  self.assertCallVariantsEmitsNRecordsForRandomGuess(
      test_utils.test_tmpfile('empty_1st_shard@2'), len(golden))
def test_get_shape_from_examples_path(self, file_name_to_write,
                                      tfrecord_path_to_match):
  """DeepVariantDataSet.tensor_shape reflects the examples' image/shape."""
  expected_shape = [1, 2, 3]
  example = example_pb2.Example()
  shape_feature = example.features.feature['image/shape']
  shape_feature.int64_list.value.extend(expected_shape)
  io_utils.write_tfrecords([example],
                           test_utils.test_tmpfile(file_name_to_write))
  dataset = data_providers.DeepVariantDataSet(
      name='test_shape',
      source=test_utils.test_tmpfile(tfrecord_path_to_match),
      num_examples=1)
  self.assertEqual(expected_shape, dataset.tensor_shape)
def test_call_variants_runs_on_gpus(self, model):
  """call_variants completes when execution_hardware='accelerator'."""
  outfile = test_utils.test_tmpfile('zzz.tfrecord')
  call_variants.call_variants(
      examples_filename=test_utils.GOLDEN_CALLING_EXAMPLES,
      checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
      model=model,
      execution_hardware='accelerator',
      output_file=outfile)
def test_call_variants_with_empty_input(self):
  """prepare_inputs must not crash when the examples file has no records."""
  empty_path = test_utils.test_tmpfile('empty.tfrecord')
  io_utils.write_tfrecords([], empty_path)
  # Getting through this call without an exception is the whole assertion.
  call_variants.prepare_inputs(
      empty_path, modeling.get_model('random_guess'), batch_size=1)
def _test_dataset_config(filename, **kwargs):
  """Writes a DeepVariantDatasetConfig built from kwargs to a temp pbtxt.

  Args:
    filename: name passed to test_utils.test_tmpfile to build the output path.
    **kwargs: fields forwarded to deepvariant_pb2.DeepVariantDatasetConfig.

  Returns:
    The path of the written pbtxt file.
  """
  pbtxt_path = test_utils.test_tmpfile(filename)
  data_providers.write_dataset_config_to_pbtxt(
      deepvariant_pb2.DeepVariantDatasetConfig(**kwargs), pbtxt_path)
  return pbtxt_path
def testModelShapes(self):
  """model_shapes reports the shapes of variables saved in a checkpoint."""
  # Builds a graph with two variables of different ranks.
  v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name='v0')
  v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]],
                   dtype=tf.float32,
                   name='v1')
  # tf.initialize_all_variables() is deprecated;
  # tf.global_variables_initializer() is its replacement (and is what the
  # other tests here use).
  init_all_op = tf.global_variables_initializer()
  save = tf.train.Saver({'v0': v0, 'v1': v1})
  save_path = test_utils.test_tmpfile('ckpt_for_debug_string')
  with tf.Session() as sess:
    sess.run(init_all_op)
    # Saves a checkpoint.
    save.save(sess, save_path)

    # Model shapes without any variable requests gives you all variables.
    self.assertEqual({
        'v0': (2, 3),
        'v1': (3, 2, 1)
    }, tf_utils.model_shapes(save_path))
    # Asking for v0 gives you only v0's shape.
    self.assertEqual({'v0': (2, 3)}, tf_utils.model_shapes(save_path, ['v0']))
    # Asking for v1 gives you only v1's shape.
    self.assertEqual({'v1': (3, 2, 1)},
                     tf_utils.model_shapes(save_path, ['v1']))

    # Verifies model_shapes() fails for non-existent tensors.
    with self.assertRaisesRegexp(KeyError, 'v3'):
      tf_utils.model_shapes(save_path, ['v3'])
def test_call_end2end(self, compressed_inputs):
  """postprocess_variants VCF and gVCF outputs match the golden files.

  Args:
    compressed_inputs: forwarded to make_golden_dataset to choose between the
      golden input path and a '.tfrecord.gz' copy of it.
  """
  FLAGS.infile = make_golden_dataset(compressed_inputs)
  FLAGS.ref = test_utils.CHR20_FASTA
  FLAGS.outfile = test_utils.test_tmpfile('calls.vcf')
  FLAGS.nonvariant_site_tfrecord_path = (
      test_utils.GOLDEN_POSTPROCESS_GVCF_INPUT)
  FLAGS.gvcf_outfile = test_utils.test_tmpfile('gvcf_calls.vcf')
  postprocess_variants.main(['postprocess_variants.py'])

  def read_lines(path):
    # Close each handle deterministically. The previous inline
    # FastGFile(...).readlines() calls never closed their files.
    with tf.gfile.FastGFile(path) as fin:
      return fin.readlines()

  self.assertEqual(
      read_lines(FLAGS.outfile),
      read_lines(test_utils.GOLDEN_POSTPROCESS_OUTPUT))
  self.assertEqual(
      read_lines(FLAGS.gvcf_outfile),
      read_lines(test_utils.GOLDEN_POSTPROCESS_GVCF_OUTPUT))
def test_call_variants_with_invalid_format(self, model, bad_format):
  """call_variants raises ValueError when image/format is not 'raw'."""
  # Start from one known-good golden record...
  good_example = next(
      io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))
  # ...and overwrite its image/format with an unsupported value.
  format_field = good_example.features.feature['image/format']
  format_field.bytes_list.value[0] = bad_format
  examples_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
  io_utils.write_tfrecords([good_example], examples_path)
  output_path = test_utils.test_tmpfile(
      'call_variants_invalid_format.tfrecord')
  with self.assertRaises(ValueError):
    call_variants.call_variants(
        examples_filename=examples_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=output_path,
        batch_size=1,
        max_batches=1)
def _run():
  """Runs call_variants for one batch of size 1 on `hardware_env`.

  This is a closure: `self` and `hardware_env` come from the enclosing scope.
  """
  outfile = test_utils.test_tmpfile('zzz.tfrecord')
  call_variants.call_variants(
      examples_filename=test_utils.GOLDEN_CALLING_EXAMPLES,
      checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
      model=self.model,
      execution_hardware=hardware_env,
      max_batches=1,
      batch_size=1,
      output_file=outfile)
def test_call_end2end_with_empty_shards(self):
  """call_variants copes with inputs where several shards hold no records."""
  # Cap the input at 10 golden examples.
  examples = list(
      io_utils.read_tfrecords(
          test_utils.GOLDEN_CALLING_EXAMPLES, max_records=10))
  # 15 shards for at most 10 records guarantees multiple empty shards.
  num_shards = 15
  sharded_path = test_utils.test_tmpfile('sharded@%d' % num_shards)
  io_utils.write_tfrecords(examples, sharded_path)
  self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_path,
                                                     len(examples))
def test_call_variants_with_no_shape(self, model):
  """prepare_inputs rejects examples that lack an image/shape feature."""
  example = next(io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))
  del example.features.feature['image/shape']
  noshape_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
  io_utils.write_tfrecords([example], noshape_path)
  expected_message = ('Invalid image/shape: we expect to find an image/shape '
                      'field with length 3.')
  with self.assertRaisesRegexp(ValueError, expected_message):
    call_variants.prepare_inputs(noshape_path, model, batch_size=1)
def test_catches_bad_flags(self):
  """main() logs an error and exits in training mode w/o confident_regions."""
  region = ranges.parse_literal('chr20:10,000,000-10,010,000')
  # Every flag below is valid except the final one.
  FLAGS.ref = test_utils.CHR20_FASTA
  FLAGS.reads = test_utils.CHR20_BAM
  FLAGS.candidates = test_utils.test_tmpfile('vsc.tfrecord')
  FLAGS.examples = test_utils.test_tmpfile('examples.tfrecord')
  FLAGS.regions = [ranges.to_literal(region)]
  FLAGS.partition_size = 1000
  FLAGS.mode = 'training'
  FLAGS.truth_variants = test_utils.TRUTH_VARIANTS_VCF
  FLAGS.confident_regions = ''  # The bad flag: training mode needs this.

  with mock.patch.object(logging, 'error') as mock_logging:
    with mock.patch.object(sys, 'exit') as mock_exit:
      make_examples.main(['make_examples.py'])

  mock_logging.assert_called_once_with(
      'confident_regions is required when in training mode.')
  mock_exit.assert_called_once_with(errno.ENOENT)
def make_golden_dataset(compressed_inputs=False):
  """Returns a path to the golden postprocess input records.

  Args:
    compressed_inputs: if False, return test_utils.GOLDEN_POSTPROCESS_INPUT
      as is. If True, copy its CallVariantsOutput records to a temp path ending
      in '.tfrecord.gz' and return that path instead (presumably io_utils
      compresses based on the suffix; not visible here).
  """
  if not compressed_inputs:
    return test_utils.GOLDEN_POSTPROCESS_INPUT
  gz_path = test_utils.test_tmpfile(
      'golden.postprocess_single_site_input.tfrecord.gz')
  golden_records = io_utils.read_tfrecords(
      test_utils.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)
  io_utils.write_tfrecords(golden_records, gz_path)
  return gz_path
def make_golden_dataset(compressed_inputs=False):
  """Wraps the golden training examples in a DeepVariantDataSet.

  Args:
    compressed_inputs: if True, the examples are first copied to a temp path
      ending in '.tfrecord.gz', and the dataset reads from that copy.

  Returns:
    A DeepVariantDataSet named 'labeled_golden' with num_examples=49.
  """
  source_path = test_utils.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    source_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    io_utils.write_tfrecords(
        io_utils.read_tfrecords(test_utils.GOLDEN_TRAINING_EXAMPLES),
        source_path)
  return data_providers.DeepVariantDataSet(
      name='labeled_golden', source=source_path, num_examples=49)
def test_call_variants_non_accelerated_execution_runs(self,
                                                      execution_hardware):
  """One single-example batch of call_variants completes on this hardware."""
  # list_devices is NOT mocked here, which is why this test is kept even
  # though it closely resembles the parameterized test below.
  call_variants.call_variants(
      examples_filename=test_utils.GOLDEN_CALLING_EXAMPLES,
      checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
      model=self.model,
      execution_hardware=execution_hardware,
      max_batches=1,
      batch_size=1,
      output_file=test_utils.test_tmpfile('call_variants_cpu_only.tfrecord'))
def test_catches_bad_argv(self):
  """main() logs an error and exits when positional args are present."""
  # Give every flag a valid value so only argv can trigger the failure.
  FLAGS.infile = make_golden_dataset(False)
  FLAGS.ref = test_utils.CHR20_FASTA
  FLAGS.outfile = test_utils.test_tmpfile('nonempty_outfile.vcf')

  with mock.patch.object(logging, 'error') as mock_logging:
    with mock.patch.object(sys, 'exit') as mock_exit:
      postprocess_variants.main(['postprocess_variants.py', 'extra_arg'])

  expected_error = (
      'Command line parsing failure: postprocess_variants does not accept '
      'positional arguments but some are present on the command line: '
      '"[\'postprocess_variants.py\', \'extra_arg\']".')
  mock_logging.assert_called_once_with(expected_error)
  mock_exit.assert_called_once_with(errno.ENOENT)
def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                  num_examples):
  """Asserts random_guess call_variants writes num_examples output protos.

  Args:
    filename: examples path (possibly sharded) passed to call_variants.
    num_examples: expected number of CallVariantsOutput records.
  """
  outfile = test_utils.test_tmpfile('call_variants.tfrecord')
  call_variants.call_variants(
      examples_filename=filename,
      checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
      model=modeling.get_model('random_guess'),
      output_file=outfile,
      batch_size=4,
      max_batches=None)
  emitted = io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput)
  # Check that we have the right number of output protos.
  self.assertEqual(len(list(emitted)), num_examples)
def test_reading_sharded_dataset(self, compressed_inputs):
  """A dataset config that points at sharded TFRecords yields golden data."""
  golden = make_golden_dataset(compressed_inputs)
  num_shards = 3
  sharded_path = test_utils.test_tmpfile('sharded@%d' % num_shards)
  # Re-write the golden records across num_shards files.
  io_utils.write_tfrecords(
      io_utils.read_tfrecords(golden.source), sharded_path)
  config_path = _test_dataset_config(
      'test_sharded.pbtxt',
      name='sharded_test',
      tfrecord_path=sharded_path,
      num_examples=golden.num_examples)
  sharded_dataset = data_providers.get_dataset(config_path)
  self.assertDataSetExamplesMatchExpected(
      sharded_dataset.get_slim_dataset(), golden)
def _run_tiny_training(self, model_name, dataset):
  """Trains model_name for one step on `dataset` and checks a checkpoint.

  data_providers.get_dataset is patched to return `dataset`. The
  dataset_config_pbtxt flag is a placeholder path that is only checked as the
  argument get_dataset was called with.
  """
  with mock.patch(
      'deepvariant.data_providers.get_dataset') as mock_get_dataset:
    mock_get_dataset.return_value = dataset
    flag_settings = (
        ('train_dir', test_utils.test_tmpfile(model_name)),
        ('batch_size', 2),
        ('model_name', model_name),
        ('save_interval_secs', 0),
        ('number_of_steps', 1),
        ('dataset_config_pbtxt', '/path/to/mock.pbtxt'),
        ('start_from_checkpoint', ''),
    )
    for flag_name, flag_value in flag_settings:
      setattr(FLAGS, flag_name, flag_value)
    model_train.parse_and_run()
    # Training must have read the (mocked) dataset once and left a checkpoint.
    mock_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
    self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))
def test_prepare_inputs(self, filename, expand_to_file_pattern):
  """prepare_inputs yields exactly the variants stored in self.examples."""
  source_path = test_utils.test_tmpfile(filename)
  io_utils.write_tfrecords(self.examples, source_path)
  if expand_to_file_pattern:
    # Rewrites e.g. foo@3 as foo-?????-of-00003.
    source_path = io_utils.NormalizeToShardedFilePattern(source_path)

  with self.test_session() as sess:
    _, variants, _ = call_variants.prepare_inputs(
        source_path, self.model, batch_size=1)
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    # Drain the input pipeline until it signals end-of-data.
    seen_variants = []
    while True:
      try:
        batch = sess.run(variants)
      except tf.errors.OutOfRangeError:
        break
      seen_variants.extend(batch)

    self.assertItemsEqual(self.variants,
                          variantutils.decode_variants(seen_variants))
def test_call_end2end(self, model, shard_inputs, include_debug_info):
  """Runs call_variants on the golden examples and validates its outputs.

  Args:
    model: model under test. 'random_guess' is run over every example; any
      other model is run for a single batch only.
    shard_inputs: if True, the golden examples are first re-written into 3
      shards.
    include_debug_info: assigned to FLAGS.include_debug_info. It decides
      whether CallVariantsOutput.debug_info is expected to be filled.
  """
  FLAGS.include_debug_info = include_debug_info
  examples = list(io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))

  if shard_inputs:
    # Create a sharded version of our golden examples.
    source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
    io_utils.write_tfrecords(examples, source_path)
  else:
    source_path = test_utils.GOLDEN_CALLING_EXAMPLES

  batch_size = 4
  if model.name == 'random_guess':
    # For the random guess model we can run everything.
    max_batches = None
  else:
    # For all other models we only run a single batch for inference.
    max_batches = 1

  outfile = test_utils.test_tmpfile('call_variants.tfrecord')
  call_variants.call_variants(
      examples_filename=source_path,
      checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
      model=model,
      output_file=outfile,
      batch_size=batch_size,
      max_batches=max_batches)

  call_variants_outputs = list(
      io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

  # Check that we have the right number of output protos.
  self.assertEqual(
      len(call_variants_outputs),
      batch_size * max_batches if max_batches else len(examples))

  # Check that our CallVariantsOutput (CVO) have the following critical
  # properties:
  # - we have one CVO for each example we processed.
  # - the variant in the CVO is exactly what was in the example.
  # - the alt_allele_indices of the CVO match those of its corresponding
  #   example.
  # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
  # The one-to-one variant comparison is only possible when every example was
  # processed (max_batches is None).
  if max_batches is None:
    self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                          [tf_utils.example_variant(ex) for ex in examples])

  # Check the CVO debug_info: not filled if include_debug_info is False;
  # else, filled by logic based on CVO.
  if not include_debug_info:
    for cvo in call_variants_outputs:
      self.assertEqual(cvo.debug_info,
                       deepvariant_pb2.CallVariantsOutput.DebugInfo())
  else:
    for cvo in call_variants_outputs:
      self.assertEqual(cvo.debug_info.has_insertion,
                       variantutils.has_insertion(cvo.variant))
      self.assertEqual(cvo.debug_info.has_deletion,
                       variantutils.has_deletion(cvo.variant))
      self.assertEqual(cvo.debug_info.is_snp,
                       variantutils.is_snp(cvo.variant))
      self.assertEqual(cvo.debug_info.predicted_label,
                       np.argmax(cvo.genotype_probabilities))

  def example_matches_call_variants_output(example, call_variants_output):
    return (tf_utils.example_variant(example) == call_variants_output.variant
            and tf_utils.example_alt_alleles_indices(example) ==
            call_variants_output.alt_allele_indices.indices)

  for call_variants_output in call_variants_outputs:
    # Find all matching examples.
    matches = [
        ex for ex in examples
        if example_matches_call_variants_output(ex, call_variants_output)
    ]
    # We should have exactly one match.
    self.assertEqual(len(matches), 1)
    example = matches[0]
    # Check that we've faithfully copied in the alt alleles (though currently
    # as implemented we find our example using this information so it cannot
    # fail). Included here in case that changes in the future.
    self.assertEqual(
        list(tf_utils.example_alt_alleles_indices(example)),
        list(call_variants_output.alt_allele_indices.indices))
    # We should have exactly three genotype probabilities (assuming our
    # ploidy == 2).
    self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
    # These are probabilities so they should be between 0 and 1. The previous
    # assertTrue(<generator>) could never fail, because a generator object is
    # always truthy. Check each value explicitly instead.
    for gp in call_variants_output.genotype_probabilities:
      self.assertGreaterEqual(gp, 0.0)
      self.assertLessEqual(gp, 1.0)
def test_realigner_diagnostics(self, enabled, emit_reads):
  """Checks which realigner diagnostic files are written to disk.

  Args:
    enabled: value for config.diagnostics.enabled. When False, the diagnostics
      root directory must not be created.
    emit_reads: value for config.diagnostics.emit_realigned_reads. It controls
      whether a realigned-reads TFRecord is expected.
  """
  # Diagnostics output root. It should only exist when diagnostics are on.
  dx_dir = test_utils.test_tmpfile('dx')
  region_str = 'chr20:10046179-10046188'
  region = ranges.parse_literal(region_str)
  # The window the realigner is expected to assemble around `region`.
  assembled_region_str = 'chr20:10046109-10046257'
  reads = _get_reads(region)
  self.config = realigner.realigner_config(FLAGS)
  self.config.diagnostics.enabled = enabled
  self.config.diagnostics.output_root = dx_dir
  self.config.diagnostics.emit_realigned_reads = emit_reads
  self.reads_realigner = realigner.Realigner(self.config, self.ref_reader)
  _, realigned_reads = self.reads_realigner.realign_reads(reads, region)
  self.reads_realigner.diagnostic_logger.close()  # Force close all resources.

  if not enabled:
    # Make sure our diagnostic output isn't emitted.
    self.assertFalse(tf.gfile.Exists(dx_dir))
  else:
    # Our root directory exists.
    self.assertTrue(tf.gfile.IsDirectory(dx_dir))

    # We expect a realigner_metrics.csv in our rootdir with 1 entry in it.
    metrics_file = os.path.join(
        dx_dir, self.reads_realigner.diagnostic_logger.metrics_filename)
    self.assertTrue(tf.gfile.Exists(metrics_file))
    with tf.gfile.FastGFile(metrics_file) as fin:
      rows = list(csv.DictReader(fin))
      self.assertEqual(len(rows), 1)
      self.assertEqual(
          set(rows[0].keys()), {'window', 'k', 'n_haplotypes', 'time'})
      self.assertEqual(rows[0]['window'], assembled_region_str)
      self.assertEqual(int(rows[0]['k']), 25)
      # assertTrue(x, 2) only checked that x is non-zero, because 2 was used
      # as the failure message. Compare against the intended count instead.
      # NOTE(review): confirm 2 is the expected haplotype count for this
      # window.
      self.assertEqual(int(rows[0]['n_haplotypes']), 2)
      # Check that our runtime is reasonable (greater than 0, less than 10 s).
      self.assertTrue(0.0 < float(rows[0]['time']) < 10.0)

    # As does the subdirectory for this region.
    region_subdir = os.path.join(dx_dir, assembled_region_str)
    self.assertTrue(tf.gfile.IsDirectory(region_subdir))

    # We always have a graph.dot
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(region_subdir,
                         self.reads_realigner.diagnostic_logger.graph_filename)))

    reads_file = os.path.join(
        dx_dir, region_str,
        self.reads_realigner.diagnostic_logger.realigned_reads_filename)
    if emit_reads:
      self.assertTrue(tf.gfile.Exists(reads_file))
      reads_from_dx = io_utils.read_tfrecords(reads_file, reads_pb2.Read)
      self.assertCountEqual(reads_from_dx, realigned_reads)
    else:
      self.assertFalse(tf.gfile.Exists(reads_file))
def test_make_examples_end2end(self, mode, num_shards):
  """Runs make_examples over a 10kb chr20 window and validates its outputs.

  Args:
    mode: 'calling' or 'training'.
    num_shards: shard count passed to _sharded for output names. The runner
      is invoked max(num_shards, 1) times, once per task id.
  """
  self.assertIn(mode, {'calling', 'training'})
  is_calling = mode == 'calling'
  region = ranges.parse_literal('chr20:10,000,000-10,010,000')
  FLAGS.ref = test_utils.CHR20_FASTA
  FLAGS.reads = test_utils.CHR20_BAM
  FLAGS.candidates = test_utils.test_tmpfile(
      _sharded('vsc.tfrecord', num_shards))
  FLAGS.examples = test_utils.test_tmpfile(
      _sharded('examples.tfrecord', num_shards))
  FLAGS.regions = [ranges.to_literal(region)]
  FLAGS.partition_size = 1000
  FLAGS.mode = mode
  if is_calling:
    FLAGS.gvcf = test_utils.test_tmpfile(
        _sharded('gvcf.tfrecord', num_shards))
  else:
    FLAGS.truth_variants = test_utils.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = test_utils.CONFIDENT_REGIONS_BED

  # Run every task. The `options` built for the last task is reused by the
  # verification helpers below.
  for task_id in range(max(num_shards, 1)):
    FLAGS.task = task_id
    options = make_examples.default_options(add_flags=True)
    make_examples.make_examples_runner(options)

  # The candidates must pass the per-call and per-variant sanity checks.
  candidates = _sort_candidates(
      io_utils.read_tfrecords(
          FLAGS.candidates, proto=deepvariant_pb2.DeepVariantCall))
  self.verify_deepvariant_calls(candidates, options)
  candidate_variants = [call.variant for call in candidates]
  self.verify_variants(candidate_variants, region, options, is_gvcf=False)

  # The variants carried by the examples must be sane too.
  examples = self.verify_examples(
      FLAGS.examples, region, options, verify_labels=not is_calling)
  example_variants = [tf_utils.example_variant(ex) for ex in examples]
  self.verify_variants(example_variants, region, options, is_gvcf=False)

  # Compare against the golden examples for this mode. The order is expected
  # to be deterministic in both modes because the random seed is fixed.
  golden_source = (
      test_utils.GOLDEN_CALLING_EXAMPLES
      if is_calling else test_utils.GOLDEN_TRAINING_EXAMPLES)
  golden_file = _sharded(golden_source, num_shards)
  self.assertDeepVariantExamplesEqual(
      examples, list(io_utils.read_tfrecords(golden_file)))

  if not is_calling:
    return

  # Calling mode only: concordance with the truth set and gVCF checks.
  nist_reader = genomics_io.make_vcf_reader(test_utils.TRUTH_VARIANTS_VCF)
  nist_variants = list(nist_reader.query(region))
  self.verify_nist_concordance(example_variants, nist_variants)

  gvcfs = _sort_variants(
      io_utils.read_tfrecords(FLAGS.gvcf, proto=variants_pb2.Variant))
  self.verify_variants(gvcfs, region, options, is_gvcf=True)
  self.verify_contiguity(gvcfs, region)