def setUp(self):
  """Loads the canned GFF records and builds the header shared by the tests."""
  gff_tfrecord_path = test_utils.genomics_core_testdata(
      'test_features.gff.tfrecord')
  record_iter = tfrecord.read_tfrecords(
      gff_tfrecord_path, proto=gff_pb2.GffRecord)
  self.records = list(record_iter)

  # The header declares a single sequence region on contig 'ctg123'.
  sequence_region = ranges.make_range('ctg123', 0, 1497228)
  self.header = gff_pb2.GffHeader(sequence_regions=[sequence_region])
def examples_to_variants(examples_path, max_records=None):
  """Yields Variant protos from the examples in examples_path.

  Reads the tf.Examples stored at examples_path (which may be a sharded spec),
  extracts the variant from each example, sorts the variants by their range
  tuple, and yields one representative variant per distinct range. When
  several examples share the same range (e.g. different alt_allele
  combinations of the same site), the first one in sorted order is used.

  Args:
    examples_path: str. Path, or sharded spec, to labeled tf.Examples produced
      by DeepVariant in training mode.
    max_records: int or None. Maximum number of records to read, or None, to
      read all of the records.

  Yields:
    nucleus.protos.Variant protos in coordinate-sorted order.

  Raises:
    ValueError: if we find a Variant in any example that doesn't have
      genotypes.
  """
  range_key = variant_utils.variant_range_tuple
  raw_examples = tfrecord.read_tfrecords(examples_path, max_records=max_records)
  sorted_variants = sorted(
      map(tf_utils.example_variant, raw_examples), key=range_key)

  for _, same_range_variants in itertools.groupby(sorted_variants, range_key):
    # Only the first variant of each range group is used as representative.
    representative = next(same_range_variants)
    call = variant_utils.only_call(representative)
    if not variantcall_utils.has_genotypes(call):
      raise ValueError((
          'Variant {} does not have any genotypes. This tool only works with '
          'variants that have been labeled.').format(
              variant_utils.variant_key(representative)))
    yield representative
def _transform_call_variants_output_to_variants(input_sorted_tfrecord_path,
                                                qual_filter,
                                                multi_allelic_qual_filter,
                                                sample_name):
  """Yields Variant protos in sorted order from CallVariantsOutput protos.

  Consecutive CallVariantsOutput records that share the same variant range are
  grouped, merged into a single canonical variant with its predictions, and
  converted into a Variant proto carrying a call. Two filters are forwarded to
  the helpers that do the conversion:
    1) `qual_filter` is passed to `add_call_to_variant`.
    2) `multi_allelic_qual_filter` is passed to `merge_predictions`.

  Args:
    input_sorted_tfrecord_path: str. TFRecord format file containing sorted
      CallVariantsOutput protos.
    qual_filter: double. The qual value below which to filter variants.
    multi_allelic_qual_filter: double. The qual value below which to filter
      multi-allelic variants.
    sample_name: str. Sample name to write to VCF file.

  Yields:
    Variant protos in sorted order representing the CallVariantsOutput calls.
  """

  def _range_of(call_variants_output):
    """Grouping key: the genomic range of the record's variant."""
    return variant_utils.variant_range(call_variants_output.variant)

  call_variants_outputs = tfrecord.read_tfrecords(
      input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)

  # groupby only merges adjacent records, hence the sorted-input requirement.
  for _, same_site_outputs in itertools.groupby(call_variants_outputs,
                                                _range_of):
    ordered_outputs = _sort_grouped_variants(same_site_outputs)
    canonical_variant, predictions = merge_predictions(
        ordered_outputs, multi_allelic_qual_filter)
    yield add_call_to_variant(
        canonical_variant,
        predictions,
        qual_filter=qual_filter,
        sample_name=sample_name)
def test_make_examples_training_end2end_with_alt_aligned_pileup(
    self, alt_align, expected_shape):
  """Runs make_examples in training mode for one alt_aligned_pileup setting.

  Args:
    alt_align: str. Value for --alt_aligned_pileup; golden data exists only
      for 'rows' and 'diff_channels'.
    expected_shape: the expected 'image/shape' of the first emitted example.
  """
  region = ranges.parse_literal('chr20:10,000,000-10,010,000')
  region_literal = ranges.to_literal(region)
  FLAGS.regions = [region_literal]
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.CHR20_BAM
  FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'training'
  FLAGS.gvcf_gq_binsize = 5
  # This is the only input change.
  FLAGS.alt_aligned_pileup = alt_align
  FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
  FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
  options = make_examples.default_options(add_flags=True)

  # Run make_examples with the flags above.
  make_examples_core.make_examples_runner(options)

  # Pick the golden file that corresponds to the requested pileup layout.
  if alt_align not in ('rows', 'diff_channels'):
    raise ValueError("Golden data doesn't exist for this alt_align option: "
                     '{}'.format(alt_align))
  if alt_align == 'rows':
    golden_path = _sharded(testdata.ALT_ALIGNED_ROWS_EXAMPLES)
  else:
    golden_path = _sharded(testdata.ALT_ALIGNED_DIFF_CHANNELS_EXAMPLES)

  # Verify that the variants in the examples are all good.
  produced_examples = self.verify_examples(
      FLAGS.examples, region, options, verify_labels=True)
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(produced_examples, golden_examples)

  # The pileup image shape depends on the layout, so the expected value is
  # supplied by the test parameters.
  first_shape = decode_example(produced_examples[0])['image/shape']
  self.assertEqual(first_shape, expected_shape)
def test_make_examples_end2end_vcf_candidate_importer(self, mode):
  """Runs make_examples with the vcf_candidate_importer variant caller.

  Args:
    mode: str. The make_examples mode. 'calling' uses proposed variants and a
      fixed region; any other value uses the truth variants.
  """
  candidates_name = 'vcf_candidate_importer.{}.tfrecord'.format(mode)
  examples_name = 'vcf_candidate_importer.examples.{}.tfrecord'.format(mode)

  FLAGS.variant_caller = 'vcf_candidate_importer'
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.CHR20_BAM
  FLAGS.candidates = test_utils.test_tmpfile(_sharded(candidates_name))
  FLAGS.examples = test_utils.test_tmpfile(_sharded(examples_name))
  FLAGS.mode = mode

  if mode != 'calling':
    golden_path = _sharded(
        testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_TRAINING_EXAMPLES)
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
  else:
    golden_path = _sharded(
        testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_CALLING_EXAMPLES)
    FLAGS.proposed_variants = testdata.VCF_CANDIDATE_IMPORTER_VARIANTS
    # Adding the following flags to match how the testdata was created.
    FLAGS.regions = 'chr20:59,777,000-60,000,000'
    FLAGS.realign_reads = False

  options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(options)

  # Verify that the variants in the examples are all good. Labels are only
  # checked when running in training mode.
  check_labels = mode == 'training'
  produced_examples = self.verify_examples(
      FLAGS.examples, None, options, verify_labels=check_labels)
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(produced_examples, golden_examples)

  first_shape = decode_example(produced_examples[0])['image/shape']
  self.assertEqual(first_shape, [100, 221, dv_constants.PILEUP_NUM_CHANNELS])
def test_make_examples_with_allele_frequency(self, mode):
  """Runs make_examples in calling mode with the allele frequency channel.

  Args:
    mode: str. 'one vcf' to use a single population VCF, 'two vcfs' to use a
      space-separated pair of population VCFs.

  Raises:
    ValueError: if `mode` is neither of the two supported values.
  """
  FLAGS.mode = 'calling'
  FLAGS.ref = testdata.GRCH38_FASTA
  FLAGS.reads = testdata.GRCH38_CHR20_AND_21_BAM
  num_shards = 1
  FLAGS.examples = test_utils.test_tmpfile(
      _sharded('examples.tfrecord', num_shards))
  region = ranges.parse_literal('chr20:61001-62000')
  FLAGS.use_allele_frequency = True
  FLAGS.regions = [ranges.to_literal(region)]

  if mode == 'one vcf':
    FLAGS.population_vcfs = testdata.AF_VCF_CHR20_AND_21
  elif mode == 'two vcfs':
    vcf_paths = [testdata.AF_VCF_CHR20, testdata.AF_VCF_CHR21]
    FLAGS.population_vcfs = ' '.join(vcf_paths)
  else:
    raise ValueError('Invalid mode for parameterized test.')

  options = make_examples.default_options(add_flags=True)
  # Run make_examples with the flags above.
  make_examples_core.make_examples_runner(options)

  # Verify that the variants in the examples are all good.
  examples = self.verify_examples(
      FLAGS.examples, region, options, verify_labels=False)

  # Pileup images should have one extra channel.
  first_shape = decode_example(examples[0])['image/shape']
  self.assertEqual([100, 221, dv_constants.PILEUP_NUM_CHANNELS + 1],
                   first_shape)

  # Test there is something in the added channel.
  # Each value records whether that locus has been seen in the examples.
  population_matched_loci = {
      'chr20:61539_A': False,
      'chr20:61634_G': False,
      'chr20:61644_G': False
  }

  for example in examples:
    locus_id = vis.locus_id_from_variant(vis.variant_from_example(example))
    if locus_id not in population_matched_loci:
      continue
    channels = vis.channels_from_example(example)
    added_channel_total = np.sum(channels[dv_constants.PILEUP_NUM_CHANNELS])
    self.assertGreater(
        added_channel_total,
        0,
        msg='There should be '
        'something in the %s-th channel for variant '
        '%s' % (dv_constants.PILEUP_NUM_CHANNELS + 1, locus_id))
    population_matched_loci[locus_id] = True

  every_locus_seen = all(population_matched_loci.values())
  self.assertTrue(
      every_locus_seen,
      msg='Check that all '
      '3 sample loci appeared in the examples.')

  # Check against the golden file (same for both modes).
  golden_path = _sharded(testdata.GOLDEN_ALLELE_FREQUENCY_EXAMPLES)
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(golden_examples, examples)
def test_call_end2end_with_empty_shards(self):
  """Checks call_variants output count when some input shards are empty."""
  # Get only up to 10 examples.
  max_examples = 10
  golden_examples = list(
      tfrecord.read_tfrecords(
          testdata.GOLDEN_CALLING_EXAMPLES, max_records=max_examples))

  # Write to 15 shards, which means there will be multiple empty shards.
  num_shards = 15
  sharded_path = test_utils.test_tmpfile('sharded@{}'.format(num_shards))
  tfrecord.write_tfrecords(golden_examples, sharded_path)

  self.assertCallVariantsEmitsNRecordsForRandomGuess(
      sharded_path, len(golden_examples))
def make_golden_dataset(compressed_inputs=False):
  """Returns the path to the golden postprocess input.

  Args:
    compressed_inputs: bool. If true, the golden CallVariantsOutput records
      are re-written to a temporary file whose name ends in '.tfrecord.gz'
      and that path is returned instead of the original golden path.

  Returns:
    str. Path to the TFRecord file to use as input.
  """
  if not compressed_inputs:
    return testdata.GOLDEN_POSTPROCESS_INPUT

  gz_path = test_utils.test_tmpfile(
      'golden.postprocess_single_site_input.tfrecord.gz')
  golden_records = tfrecord.read_tfrecords(
      testdata.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)
  tfrecord.write_tfrecords(golden_records, gz_path)
  return gz_path
def test_call_end2end_empty_first_shard(self):
  """Checks call_variants output count when the first of two shards is empty."""
  # Get only up to 10 examples.
  golden_examples = list(
      tfrecord.read_tfrecords(
          testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))

  # Shard 0 receives no records; shard 1 receives all of them.
  first_shard = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
  tfrecord.write_tfrecords([], first_shard)
  second_shard = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
  tfrecord.write_tfrecords(golden_examples, second_shard)

  sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
  self.assertCallVariantsEmitsNRecordsForRandomGuess(
      sharded_spec, len(golden_examples))
def test_read_tfrecords_max_records(self, filename, max_records):
  """Checks that read_tfrecords honors the max_records limit.

  Args:
    filename: str. Name of the file the test protos are written to.
    max_records: int or None. Limit passed to read_tfrecords; None means no
      limit.
  """
  protos, path = self.write_test_protos(filename)
  n_written = len(protos)

  # With no limit we expect everything back; otherwise at most max_records.
  expected_n = n_written if max_records is None else min(max_records,
                                                         n_written)

  # Create our generator of records from read_tfrecords.
  records = tfrecord.read_tfrecords(
      path, reference_pb2.ContigInfo, max_records=max_records)
  self.assertLen(list(records), expected_n)
def _call_end2end_helper(self, examples_path, model, shard_inputs):
  """Runs call_variants on examples_path and returns its outputs.

  Args:
    examples_path: str. Path to the tf.Examples to run inference on.
    model: The model to pass to call_variants; its `name` decides how many
      batches are run.
    shard_inputs: bool. If true, the examples are first re-written to a
      3-way sharded temporary file which is used as the input.

  Returns:
    A tuple (call_variants_outputs, examples, batch_size, max_batches) where
    call_variants_outputs is the list of CallVariantsOutput protos written by
    call_variants and examples is the list of input examples.
  """
  examples = list(tfrecord.read_tfrecords(examples_path))

  source_path = examples_path
  if shard_inputs:
    # Create a sharded version of our golden examples.
    source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
    tfrecord.write_tfrecords(examples, source_path)

  # If we point the test at a headless server, it will often be 2x2,
  # which has 8 replicas.  Otherwise a smaller batch size is fine.
  batch_size = 8 if FLAGS.use_tpu else 4

  # For the random guess model we can run everything. For all other models
  # we only run a single batch for inference.
  max_batches = None if model.name == 'random_guess' else 1

  output_path = test_utils.test_tmpfile('call_variants.tfrecord')
  call_variants.call_variants(
      examples_filename=source_path,
      checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
      model=model,
      output_file=output_path,
      batch_size=batch_size,
      max_batches=max_batches,
      master='',
      use_tpu=FLAGS.use_tpu,
  )

  output_records = tfrecord.read_tfrecords(
      output_path, deepvariant_pb2.CallVariantsOutput)
  return list(output_records), examples, batch_size, max_batches
def test_call_variants_with_no_shape(self, model):
  """Checks that an example without image/shape is rejected with ValueError."""
  # Read one good record from a valid file.
  golden_examples = tfrecord.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
  example = next(golden_examples)

  # Remove image/shape.
  del example.features.feature['image/shape']

  source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
  tfrecord.write_tfrecords([example], source_path)

  expected_error = ('Invalid image/shape: we expect to find an image/shape '
                    'field with length 3.')
  with six.assertRaisesRegex(self, ValueError, expected_error):
    dataset = call_variants.prepare_inputs(source_path)
    # Materialize the batches so the input pipeline actually runs.
    list(_get_infer_batches(dataset, model=model, batch_size=1))
def test_reading_sharded_input_with_empty_shards_does_not_crash(self):
  """Runs postprocess_variants on a 2-way sharded input with an empty shard."""
  golden_outputs = tfrecord.read_tfrecords(
      testdata.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)

  empty_shard = test_utils.test_tmpfile(
      'reading_empty_shard.tfrecord-00000-of-00002')
  populated_shard = test_utils.test_tmpfile(
      'reading_empty_shard.tfrecord-00001-of-00002')

  # The first shard gets no records; the second gets all the golden ones.
  tfrecord.write_tfrecords([], empty_shard)
  tfrecord.write_tfrecords(golden_outputs, populated_shard)

  FLAGS.infile = test_utils.test_tmpfile('reading_empty_shard.tfrecord@2')
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.outfile = test_utils.test_tmpfile('calls_reading_empty_shard.vcf')

  # The test passes as long as main() completes without raising.
  postprocess_variants.main(['postprocess_variants.py'])
def test_make_read_writer_tfrecords(self):
  """Writes two reads through SamWriter to a TFRecord and reads them back."""
  outfile = test_utils.test_tmpfile('test.tfrecord')
  writer = sam.SamWriter(outfile, header=self.header)

  # Test that the writer is a context manager and that we can write reads to
  # it.
  with writer:
    for read in (self.read1, self.read2):
      writer.write(read)

  # Our output should contain exactly the two reads we wrote, in order.
  written_reads = list(tfrecord.read_tfrecords(outfile, proto=reads_pb2.Read))
  self.assertEqual([self.read1, self.read2], written_reads)
def test_read_write_tfrecords(self, filename):
  """Checks that protos written to `filename` round-trip via read_tfrecords.

  Args:
    filename: str. Name of the file the test protos are written to. A name
      containing '@' is treated as a sharded spec.
  """
  protos, path = self.write_test_protos(filename)

  # Create our generator of records from read_tfrecords.
  reader = tfrecord.read_tfrecords(path, reference_pb2.ContigInfo)

  # Make sure it's actually a generator. assertIsInstance is the idiomatic
  # type check and produces a clearer failure message than comparing type().
  self.assertIsInstance(reader, types.GeneratorType)

  # Check the round-trip contents.
  if '@' in filename:
    # Sharded outputs are striped across shards, so order isn't preserved.
    self.assertCountEqual(protos, reader)
  else:
    self.assertEqual(protos, list(reader))
def _transform_call_variants_output_to_variants(input_sorted_tfrecord_path,
                                                qual_filter,
                                                multi_allelic_qual_filter,
                                                sample_name, group_variants,
                                                use_multiallelic_model):
  """Yields Variant protos in sorted order from CallVariantsOutput protos.

  CallVariantsOutput records are read from the input TFRecord, grouped, merged
  into a canonical variant plus predictions, and converted into Variant protos
  carrying a call. Two filters are forwarded to the helpers that do the
  conversion:
    1) `qual_filter` is passed to `add_call_to_variant`.
    2) `multi_allelic_qual_filter` is passed to `merge_predictions`.

  Args:
    input_sorted_tfrecord_path: str. TFRecord format file containing sorted
      CallVariantsOutput protos.
    qual_filter: double. The qual value below which to filter variants.
    multi_allelic_qual_filter: double. The qual value below which to filter
      multi-allelic variants.
    sample_name: str. Sample name to write to VCF file.
    group_variants: bool. If true, group variants that have same start and end
      position. If false, no key function is given to itertools.groupby, so
      only consecutive records that compare equal end up in the same group.
    use_multiallelic_model: if True, use a specialized model for genotype
      resolution of multiallelic cases with two alts.

  Yields:
    Variant protos in sorted order representing the CallVariantsOutput calls.
  """

  def _range_of(call_variants_output):
    """Grouping key: the genomic range of the record's variant."""
    return variant_utils.variant_range(call_variants_output.variant)

  multiallelic_model = get_multiallelic_model(
      use_multiallelic_model=use_multiallelic_model)
  grouping_key = _range_of if group_variants else None

  call_variants_outputs = tfrecord.read_tfrecords(
      input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)
  for _, grouped_outputs in itertools.groupby(call_variants_outputs,
                                              grouping_key):
    ordered_outputs = _sort_grouped_variants(grouped_outputs)
    canonical_variant, predictions = merge_predictions(
        ordered_outputs,
        multi_allelic_qual_filter,
        multiallelic_model=multiallelic_model)
    yield add_call_to_variant(
        canonical_variant,
        predictions,
        qual_filter=qual_filter,
        sample_name=sample_name)
def assertTfDataSetExamplesMatchExpected(self,
                                         input_fn,
                                         expected_dataset,
                                         use_tpu=False,
                                         workaround_list_files=False):
  """Asserts that input_fn yields the loci stored in expected_dataset's file.

  Args:
    input_fn: Callable taking a params dict and returning the dataset to
      iterate over.
    expected_dataset: Object providing `input_file_spec`, `num_examples` and
      `tensor_shape` describing the expected data.
    use_tpu: bool. If true, each locus is converted with
      tf_utils.int_tensor_to_string before comparison.
    workaround_list_files: bool. If true, the observed loci are sorted before
      being compared to the expected ones.
  """
  # Note that we use input_fn to get an iterator, while we use
  # expected_dataset to get a filename, even though they are the same
  # type (DeepVariantInput), and may even be the same object.
  observed_loci = []
  with tf.compat.v1.Session() as sess:
    params = {'batch_size': 1}
    iterator = tf.compat.v1.data.make_one_shot_iterator(input_fn(params))
    next_batch = iterator.get_next()

    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.local_variables_initializer())

    # Pull batches until the iterator signals that it is exhausted.
    while True:
      try:
        features, _ = sess.run(next_batch)
      except tf.errors.OutOfRangeError:
        break
      locus = features['locus'][0]
      if use_tpu:
        locus = tf_utils.int_tensor_to_string(locus)
      # NB, this looks like: array(['chr20:10001019-10001019'], dtype=object)
      observed_loci.append(locus)

  if workaround_list_files:
    # This really only works for loci, because those are string valued and
    # are expected to show up in sorted order.  For arbitrary data that's
    # not true.  In prod we have the version of tf that lets us turn off
    # shuffling so this path is skipped, but kokoro hits this.
    observed_loci.sort()

  expected_loci = []
  for example in tfrecord.read_tfrecords(expected_dataset.input_file_spec):
    locus_feature = example.features.feature['locus']
    expected_loci.append(locus_feature.bytes_list.value[0])
  self.assertLen(expected_loci, expected_dataset.num_examples)

  # Dump both lists on mismatch to make the assertion failure easier to read.
  if observed_loci != expected_loci:
    print('\n\nlen expected seen', len(expected_loci), len(observed_loci))
    print('\n\nexpected=', expected_loci)
    print('\n\nseen=', observed_loci)
  self.assertEqual(expected_loci, observed_loci)

  # Note that this expected shape comes from the golden dataset. If the data
  # is remade in the future, the values might need to be modified accordingly.
  self.assertEqual(dv_constants.PILEUP_DEFAULT_DIMS,
                   expected_dataset.tensor_shape)
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
  """Builds an input fn over the golden training examples.

  Args:
    compressed_inputs: bool. If true, the golden examples are first re-written
      to a temporary file whose name ends in '.tfrecord.gz', and that file is
      used as the input spec.
    mode: The estimator mode forwarded to get_input_fn_from_filespec.
    use_tpu: bool. Forwarded to get_input_fn_from_filespec.

  Returns:
    The result of data_providers.get_input_fn_from_filespec for the chosen
    input file.
  """
  source_path = testdata.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    source_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    golden_examples = tfrecord.read_tfrecords(
        testdata.GOLDEN_TRAINING_EXAMPLES)
    tfrecord.write_tfrecords(golden_examples, source_path)

  return data_providers.get_input_fn_from_filespec(
      input_file_spec=source_path,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
      name='labeled_golden',
      mode=mode,
      tensor_shape=None,
      use_tpu=use_tpu)
def test_writing_canned_records(self):
  """Tests writing all the records that are 'canned' in our tfrecord file."""
  # This file is in TFRecord format.
  canned_path = test_utils.genomics_core_testdata(
      'test_features.gff.tfrecord')
  writer_options = gff_pb2.GffWriterOptions()
  canned_records = list(
      tfrecord.read_tfrecords(canned_path, proto=gff_pb2.GffRecord))

  out_fname = test_utils.test_tmpfile('output.gff')
  gff_out = gff_writer.GffWriter.to_file(out_fname, self.header,
                                         writer_options)
  with gff_out as writer:
    for gff_record in canned_records:
      writer.write(gff_record)

  # Compare the written file line by line against the expected content.
  with open(out_fname) as written:
    written_lines = written.readlines()
  self.assertEqual(written_lines, self.expected_gff_content)
def assertCallVariantsEmitsNRecordsForInceptionV3(self, filename,
                                                  num_examples):
  """Asserts call_variants with inception_v3 emits num_examples records.

  Args:
    filename: str. Path to the examples to run call_variants on.
    num_examples: int. Expected number of CallVariantsOutput protos written.
  """
  output_path = test_utils.test_tmpfile('inception_v3.call_variants.tfrecord')
  inception_model = modeling.get_model('inception_v3')

  call_variants.call_variants(
      examples_filename=filename,
      checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
      model=inception_model,
      output_file=output_path,
      batch_size=4,
      max_batches=None)

  written_outputs = list(
      tfrecord.read_tfrecords(output_path,
                              deepvariant_pb2.CallVariantsOutput))
  # Check that we have the right number of output protos.
  self.assertEqual(len(written_outputs), num_examples)
def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                  num_examples):
  """Asserts call_variants with random_guess emits num_examples records.

  Args:
    filename: str. Path to the examples to run call_variants on.
    num_examples: int. Expected number of CallVariantsOutput protos written.
  """
  output_path = test_utils.test_tmpfile('call_variants.tfrecord')
  random_guess_model = modeling.get_model('random_guess')

  call_variants.call_variants(
      examples_filename=filename,
      checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
      model=random_guess_model,
      output_file=output_path,
      batch_size=4,
      max_batches=None,
      master='',
      use_tpu=FLAGS.use_tpu)

  written_outputs = list(
      tfrecord.read_tfrecords(output_path,
                              deepvariant_pb2.CallVariantsOutput))
  # Check that we have the right number of output protos.
  self.assertEqual(len(written_outputs), num_examples)
def test_writing_canned_records(self):
  """Tests writing all the records that are 'canned' in our tfrecord file."""
  # This file is in TFRecord format.
  canned_path = test_utils.genomics_core_testdata(
      'test_reads.fastq.tfrecord')
  writer_options = fastq_pb2.FastqWriterOptions()
  canned_records = list(
      tfrecord.read_tfrecords(canned_path, proto=fastq_pb2.FastqRecord))

  out_fname = test_utils.test_tmpfile('output.fastq')
  fastq_out = fastq_writer.FastqWriter.to_file(out_fname, writer_options)
  with fastq_out as writer:
    for fastq_record in canned_records:
      writer.write(fastq_record)

  # Compare the written file line by line against the expected content.
  with gfile.Open(out_fname, 'r') as written:
    written_lines = written.readlines()
  self.assertEqual(written_lines, self.expected_fastq_content)
def test_writing_canned_records(self):
  """Tests writing all the records that are 'canned' in our tfrecord file."""
  # This file is in TFRecord format.
  canned_path = test_utils.genomics_core_testdata(
      'test_regions.bed.tfrecord')
  bed_header = bed_pb2.BedHeader(num_fields=12)
  writer_options = bed_pb2.BedWriterOptions()
  canned_records = list(
      tfrecord.read_tfrecords(canned_path, proto=bed_pb2.BedRecord))

  out_fname = test_utils.test_tmpfile('output.bed')
  bed_out = bed_writer.BedWriter.to_file(out_fname, bed_header,
                                         writer_options)
  with bed_out as writer:
    for bed_record in canned_records:
      writer.write(bed_record)

  # Compare the written file line by line against the expected content.
  with gfile.Open(out_fname, 'r') as written:
    written_lines = written.readlines()
  self.assertEqual(written_lines, self.expected_bed_content)
def test_call_variants_with_invalid_format(self, model, bad_format):
  """Checks that call_variants raises ValueError for a bad image/format.

  Args:
    model: The model passed to call_variants.
    bad_format: The value written into the example's image/format field.
  """
  # Read one good record from a valid file.
  golden_examples = tfrecord.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
  example = next(golden_examples)

  # Overwrite the image/format field to be an invalid value
  # (anything but 'raw').
  format_feature = example.features.feature['image/format']
  format_feature.bytes_list.value[0] = bad_format

  source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
  tfrecord.write_tfrecords([example], source_path)
  output_path = test_utils.test_tmpfile(
      'call_variants_invalid_format.tfrecord')

  with self.assertRaises(ValueError):
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=model,
        output_file=output_path,
        batch_size=1,
        max_batches=1,
        use_tpu=FLAGS.use_tpu)
def test_make_examples_with_variant_selection(self,
                                              select_types,
                                              expected_count,
                                              keep_legacy_behavior=False):
  """Checks the number of candidates emitted for a variant type selection.

  Args:
    select_types: Value for --select_variant_types, or None to leave the flag
      untouched.
    expected_count: int. Expected number of candidate records written.
    keep_legacy_behavior: bool. Value for
      --keep_legacy_allele_counter_behavior.
  """
  if select_types is not None:
    FLAGS.select_variant_types = select_types

  region = ranges.parse_literal('chr20:10,000,000-10,010,000')
  region_literal = ranges.to_literal(region)
  FLAGS.regions = [region_literal]
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.CHR20_BAM
  FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'calling'
  FLAGS.keep_legacy_allele_counter_behavior = keep_legacy_behavior

  options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(options)

  written_candidates = list(tfrecord.read_tfrecords(FLAGS.candidates))
  self.assertLen(written_candidates, expected_count)
def test_reading_sharded_dataset(self, compressed_inputs, use_tpu):
  """Checks that a 3-way sharded copy of the golden dataset reads correctly.

  Args:
    compressed_inputs: bool. Forwarded to make_golden_dataset.
    use_tpu: bool. Forwarded to make_golden_dataset.
  """
  golden_dataset = make_golden_dataset(compressed_inputs, use_tpu=use_tpu)

  # Re-write the golden examples across three shards.
  n_shards = 3
  sharded_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
  golden_examples = tfrecord.read_tfrecords(golden_dataset.input_file_spec)
  tfrecord.write_tfrecords(golden_examples, sharded_path)

  config_file = _test_dataset_config(
      'test_sharded.pbtxt',
      name='sharded_test',
      tfrecord_path=sharded_path,
      num_examples=golden_dataset.num_examples)
  sharded_input_fn = data_providers.get_input_fn_from_dataset(
      config_file, mode=tf.estimator.ModeKeys.EVAL)

  self.assertTfDataSetExamplesMatchExpected(
      sharded_input_fn,
      golden_dataset,
      # workaround_list_files is needed because wildcards, and so sharded
      # files, are nondeterministically ordered (for now).
      workaround_list_files=True,
  )
def verify_examples(self, examples_filename, region, options, verify_labels):
  """Runs structural checks on the tf.Examples stored in examples_filename.

  Args:
    examples_filename: str. Path to the tf.Examples to check.
    region: Forwarded to verify_variants.
    options: Forwarded to verify_variants.
    verify_labels: bool. If true, each example must also have a 'label'
      feature.

  Returns:
    The list of examples read from examples_filename.
  """
  # Do some simple structural checks on the tf.Examples in the file.
  required_features = [
      'variant/encoded', 'locus', 'image/format', 'image/encoded',
      'alt_allele_indices/encoded'
  ]
  if verify_labels:
    required_features.append('label')

  examples = list(tfrecord.read_tfrecords(examples_filename))
  for example in examples:
    present_features = example.features.feature
    for feature_name in required_features:
      self.assertIn(feature_name, present_features)
    # pylint: disable=g-explicit-length-test
    self.assertNotEmpty(tf_utils.example_alt_alleles_indices(example))

  # Check that the variants in the examples are good.
  variants = list(map(tf_utils.example_variant, examples))
  self.verify_variants(variants, region, options, is_gvcf=False)

  return examples
def test_make_examples_training_end2end_with_customized_classes_labeler(self):
  """Runs make_examples in training mode with the customized classes labeler."""
  FLAGS.labeler_algorithm = 'customized_classes_labeler'
  FLAGS.customized_classes_labeler_classes_list = 'ref,class1,class2'
  FLAGS.customized_classes_labeler_info_field_name = 'type'

  region = ranges.parse_literal('chr20:10,000,000-10,004,000')
  region_literal = ranges.to_literal(region)
  FLAGS.regions = [region_literal]
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.CHR20_BAM
  FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'training'
  FLAGS.gvcf_gq_binsize = 5
  FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF_WITH_TYPES
  FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

  options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(options)
  golden_path = _sharded(testdata.CUSTOMIZED_CLASSES_GOLDEN_TRAINING_EXAMPLES)

  # Verify that the variants in the examples are all good.
  produced_examples = self.verify_examples(
      FLAGS.examples, region, options, verify_labels=True)
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(produced_examples, golden_examples)
def get_one_example_from_examples_path(source, proto=None):
  """Get the first record from `source`.

  Args:
    source: str. A pattern or a comma-separated list of patterns that represent
      file names.
    proto: A proto class. proto.FromString() will be called on each serialized
      record in path to parse it.

  Returns:
    The first record found in the matching files, or None if every matching
    file is empty.

  Raises:
    ValueError: if no files match `source`.
  """
  files = sharded_file_utils.glob_list_sharded_file_patterns(source)
  if not files:
    raise ValueError(
        'Cannot find matching files with the pattern "{}"'.format(source))
  for f in files:
    try:
      # Use the next() builtin rather than the generator's .next() method:
      # the method only exists on Python 2, the builtin works on both.
      return next(tfrecord.read_tfrecords(f, proto=proto))
    except StopIteration:
      # Getting a StopIteration from one next() means this file is empty.
      # Move on to the next one to try to get one example.
      pass
  return None
def test_writing_canned_variants(self):
  """Tests writing all the variants that are 'canned' in our tfrecord file."""
  # This file is in the TF record format
  canned_path = test_utils.genomics_core_testdata(
      'test_samples.vcf.golden.tfrecord')

  writer_options = variants_pb2.VcfWriterOptions()

  # Header pieces are built from small tables to keep the literals compact.
  contig_lengths = (
      ('chr1', 248956422),
      ('chr2', 242193529),
      ('chr3', 198295559),
      ('chrX', 156040895),
  )
  contigs = [
      reference_pb2.ContigInfo(name=contig_name, n_bases=contig_length)
      for contig_name, contig_length in contig_lengths
  ]

  # These filters are declared with an id only (no description).
  tranche_filter_ids = (
      'VQSRTrancheINDEL95.00to96.00',
      'VQSRTrancheINDEL96.00to97.00',
      'VQSRTrancheINDEL97.00to99.00',
      'VQSRTrancheINDEL99.00to99.50',
      'VQSRTrancheINDEL99.50to99.90',
      'VQSRTrancheINDEL99.90to99.95',
      'VQSRTrancheINDEL99.95to100.00+',
      'VQSRTrancheINDEL99.95to100.00',
      'VQSRTrancheSNP99.50to99.60',
      'VQSRTrancheSNP99.60to99.80',
      'VQSRTrancheSNP99.80to99.90',
      'VQSRTrancheSNP99.90to99.95',
      'VQSRTrancheSNP99.95to100.00+',
      'VQSRTrancheSNP99.95to100.00',
  )
  filters = [
      variants_pb2.VcfFilterInfo(
          id='PASS', description='All filters passed'),
      variants_pb2.VcfFilterInfo(id='LowQual', description=''),
  ]
  filters.extend(
      variants_pb2.VcfFilterInfo(id=filter_id)
      for filter_id in tranche_filter_ids)

  infos = [
      variants_pb2.VcfInfo(
          id='END',
          number='1',
          type='Integer',
          description='Stop position of the interval')
  ]

  # (id, number, type, description) for every FORMAT field.
  format_specs = (
      ('GT', '1', 'String', 'Genotype'),
      ('GQ', '1', 'Integer', 'Genotype Quality'),
      ('DP', '1', 'Integer', 'Read depth of all passing filters reads.'),
      ('MIN_DP', '1', 'Integer',
       'Minimum DP observed within the GVCF block.'),
      ('AD', 'R', 'Integer',
       'Read depth of all passing filters reads for each allele.'),
      ('VAF', 'A', 'Float', 'Variant allele fractions.'),
      ('PL', 'G', 'Integer', 'Genotype likelihoods, Phred encoded'),
  )
  formats = [
      variants_pb2.VcfFormatInfo(
          id=format_id,
          number=format_number,
          type=format_type,
          description=format_description)
      for format_id, format_number, format_type, format_description in
      format_specs
  ]

  vcf_header = variants_pb2.VcfHeader(
      contigs=contigs,
      sample_names=['NA12878_18_99'],
      filters=filters,
      infos=infos,
      formats=formats,
  )

  canned_variants = list(
      tfrecord.read_tfrecords(canned_path, proto=variants_pb2.Variant))
  out_fname = test_utils.test_tmpfile('output.vcf')

  # Only the first five canned variants are written and checked.
  first_five = canned_variants[:5]
  vcf_out = vcf_writer.VcfWriter.to_file(out_fname, vcf_header,
                                         writer_options)
  with vcf_out as writer:
    for variant in first_five:
      writer.write(variant)

  # Check: are the variants written as expected?
  # pylint: disable=line-too-long
  expected_vcf_content = [
      '##fileformat=VCFv4.2\n',
      '##FILTER=<ID=PASS,Description="All filters passed">\n',
      '##FILTER=<ID=LowQual,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
      '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
      'the interval">\n',
      '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
      '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
      '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
      'passing filters reads.">\n',
      '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP '
      'observed within the GVCF block.">\n',
      '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
      'passing filters reads for each allele.">\n',
      '##FORMAT=<ID=VAF,Number=A,Type=Float,Description=\"Variant allele '
      'fractions.">\n',
      '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
      'likelihoods, Phred encoded">\n',
      '##contig=<ID=chr1,length=248956422>\n',
      '##contig=<ID=chr2,length=242193529>\n',
      '##contig=<ID=chr3,length=198295559>\n',
      '##contig=<ID=chrX,length=156040895>\n',
      '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
      'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
      'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
      'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
      'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
      'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
  ]
  # pylint: enable=line-too-long

  with gfile.Open(out_fname, 'r') as written:
    written_lines = written.readlines()
  self.assertEqual(written_lines, expected_vcf_content)