Example #1
0
 def setUp(self):
     """Loads the canned GFF records and builds a header, storing both on self."""
     testdata_path = test_utils.genomics_core_testdata(
         'test_features.gff.tfrecord')
     record_iter = tfrecord.read_tfrecords(
         testdata_path, proto=gff_pb2.GffRecord)
     self.records = list(record_iter)
     # The header carries a single sequence region for contig 'ctg123'.
     ctg123_region = ranges.make_range('ctg123', 0, 1497228)
     self.header = gff_pb2.GffHeader(sequence_regions=[ctg123_region])
Example #2
0
def examples_to_variants(examples_path, max_records=None):
    """Yields one Variant proto per distinct variant range found in the examples.

    The tf.Examples at examples_path (which may be a sharded spec) are read,
    their variants are extracted and sorted by variant range, and for every
    run of variants sharing the same range only the first one is yielded.

    Args:
      examples_path: str. Path, or sharded spec, to labeled tf.Examples
        produced by DeepVariant in training mode.
      max_records: int or None. Maximum number of records to read, or None to
        read all of the records.

    Yields:
      nucleus.protos.Variant protos in coordinate-sorted order.

    Raises:
      ValueError: if a selected Variant's only call has no genotypes.
    """
    range_key = variant_utils.variant_range_tuple
    example_iter = tfrecord.read_tfrecords(
        examples_path, max_records=max_records)
    variants = [tf_utils.example_variant(ex) for ex in example_iter]
    # list.sort is stable, exactly like sorted(), so ties keep input order.
    variants.sort(key=range_key)

    for _, same_range in itertools.groupby(variants, range_key):
        # The first member of each group acts as the representative.
        representative = next(same_range)
        call = variant_utils.only_call(representative)
        if not variantcall_utils.has_genotypes(call):
            message = (
                'Variant {} does not have any genotypes. This tool only works with '
                'variants that have been labeled.')
            raise ValueError(
                message.format(variant_utils.variant_key(representative)))
        yield representative
Example #3
0
def _transform_call_variants_output_to_variants(input_sorted_tfrecord_path,
                                                qual_filter,
                                                multi_allelic_qual_filter,
                                                sample_name):
    """Yields Variant protos built from grouped CallVariantsOutput protos.

    Consecutive CallVariantsOutput records that share the same variant range
    are grouped, their predictions are merged (using
    `multi_allelic_qual_filter`), and a call is added to the resulting
    canonical variant (using `qual_filter` and `sample_name`).

    Args:
      input_sorted_tfrecord_path: str. TFRecord format file containing sorted
        CallVariantsOutput protos.
      qual_filter: double. The qual value below which to filter variants.
      multi_allelic_qual_filter: double. The qual value below which to filter
        multi-allelic variants.
      sample_name: str. Sample name to write to VCF file.

    Yields:
      Variant protos in sorted order representing the CallVariantsOutput calls.
    """

    def _range_of(call_variants_output):
        # Grouping key: the range of the variant carried by the record.
        return variant_utils.variant_range(call_variants_output.variant)

    call_outputs = tfrecord.read_tfrecords(
        input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)

    for _, same_range in itertools.groupby(call_outputs, _range_of):
        ordered_outputs = _sort_grouped_variants(same_range)
        canonical_variant, predictions = merge_predictions(
            ordered_outputs, multi_allelic_qual_filter)
        yield add_call_to_variant(
            canonical_variant,
            predictions,
            qual_filter=qual_filter,
            sample_name=sample_name)
Example #4
0
  def test_make_examples_training_end2end_with_alt_aligned_pileup(
      self, alt_align, expected_shape):
    """Runs make_examples in training mode with an alt-aligned pileup setting.

    Args:
      alt_align: value assigned to FLAGS.alt_aligned_pileup. Golden data only
        exists for 'rows' and 'diff_channels'.
      expected_shape: expected 'image/shape' of the first generated example.
    """
    region = ranges.parse_literal('chr20:10,000,000-10,010,000')
    candidates_path = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
    examples_path = test_utils.test_tmpfile(_sharded('examples.tfrecord'))

    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.CHR20_BAM
    FLAGS.candidates = candidates_path
    FLAGS.examples = examples_path
    FLAGS.partition_size = 1000
    FLAGS.mode = 'training'
    FLAGS.gvcf_gq_binsize = 5
    # The alt-aligned pileup option is the only input that varies per case.
    FLAGS.alt_aligned_pileup = alt_align
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    # Build the options from the flags and run make_examples.
    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)

    # Pick the golden file that matches the requested pileup option.
    if alt_align == 'rows':
      golden_name = testdata.ALT_ALIGNED_ROWS_EXAMPLES
    elif alt_align == 'diff_channels':
      golden_name = testdata.ALT_ALIGNED_DIFF_CHANNELS_EXAMPLES
    else:
      raise ValueError("Golden data doesn't exist for this alt_align option: "
                       '{}'.format(alt_align))
    golden_file = _sharded(golden_name)

    # Structural checks on the produced examples, labels included.
    examples = self.verify_examples(
        FLAGS.examples, region, options, verify_labels=True)
    golden_examples = list(tfrecord.read_tfrecords(golden_file))
    self.assertDeepVariantExamplesEqual(examples, golden_examples)

    # The expected image shape is supplied by the test parameters.
    first_shape = decode_example(examples[0])['image/shape']
    self.assertEqual(first_shape, expected_shape)
Example #5
0
  def test_make_examples_end2end_vcf_candidate_importer(self, mode):
    """Runs make_examples with the 'vcf_candidate_importer' variant caller.

    Args:
      mode: value assigned to FLAGS.mode. 'calling' selects the calling golden
        file and proposed variants; any other value selects the training
        golden file and truth variants.
    """
    candidates_name = 'vcf_candidate_importer.{}.tfrecord'.format(mode)
    examples_name = 'vcf_candidate_importer.examples.{}.tfrecord'.format(mode)

    FLAGS.variant_caller = 'vcf_candidate_importer'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.CHR20_BAM
    FLAGS.candidates = test_utils.test_tmpfile(_sharded(candidates_name))
    FLAGS.examples = test_utils.test_tmpfile(_sharded(examples_name))
    FLAGS.mode = mode

    if mode == 'calling':
      golden_file = _sharded(
          testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_CALLING_EXAMPLES)
      FLAGS.proposed_variants = testdata.VCF_CANDIDATE_IMPORTER_VARIANTS
      # These two flags mirror how the calling testdata was generated.
      FLAGS.regions = 'chr20:59,777,000-60,000,000'
      FLAGS.realign_reads = False
    else:
      golden_file = _sharded(
          testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_TRAINING_EXAMPLES)
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF

    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)

    # Labels are only checked when running in training mode.
    is_training = mode == 'training'
    examples = self.verify_examples(
        FLAGS.examples, None, options, verify_labels=is_training)
    golden_examples = list(tfrecord.read_tfrecords(golden_file))
    self.assertDeepVariantExamplesEqual(examples, golden_examples)

    first_shape = decode_example(examples[0])['image/shape']
    self.assertEqual(first_shape,
                     [100, 221, dv_constants.PILEUP_NUM_CHANNELS])
Example #6
0
  def test_make_examples_with_allele_frequency(self, mode):
    """Runs make_examples in calling mode with allele frequency enabled.

    Args:
      mode: 'one vcf' or 'two vcfs'; selects how FLAGS.population_vcfs is
        populated. Any other value raises ValueError.
    """
    num_shards = 1
    region = ranges.parse_literal('chr20:61001-62000')

    FLAGS.mode = 'calling'
    FLAGS.ref = testdata.GRCH38_FASTA
    FLAGS.reads = testdata.GRCH38_CHR20_AND_21_BAM
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('examples.tfrecord', num_shards))
    FLAGS.use_allele_frequency = True
    FLAGS.regions = [ranges.to_literal(region)]
    if mode == 'one vcf':
      FLAGS.population_vcfs = testdata.AF_VCF_CHR20_AND_21
    elif mode == 'two vcfs':
      per_contig_vcfs = [testdata.AF_VCF_CHR20, testdata.AF_VCF_CHR21]
      FLAGS.population_vcfs = ' '.join(per_contig_vcfs)
    else:
      raise ValueError('Invalid mode for parameterized test.')

    # Build the options from the flags and run make_examples.
    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)

    # Structural checks on the produced examples (no labels in calling mode).
    examples = self.verify_examples(
        FLAGS.examples, region, options, verify_labels=False)

    # The pileup image is expected to carry one channel beyond the default.
    extra_channel_index = dv_constants.PILEUP_NUM_CHANNELS
    self.assertEqual([100, 221, extra_channel_index + 1],
                     decode_example(examples[0])['image/shape'])

    # Tracks, per sample locus, whether it showed up among the examples.
    seen_by_locus = {
        'chr20:61539_A': False,
        'chr20:61634_G': False,
        'chr20:61644_G': False
    }

    for example in examples:
      locus_id = vis.locus_id_from_variant(vis.variant_from_example(example))
      if locus_id not in seen_by_locus:
        continue
      # The added channel must be non-empty for each sample locus.
      channels = vis.channels_from_example(example)
      failure_msg = ('There should be something in the %s-th channel for '
                     'variant %s' % (extra_channel_index + 1, locus_id))
      self.assertGreater(
          np.sum(channels[extra_channel_index]), 0, msg=failure_msg)
      seen_by_locus[locus_id] = True

    self.assertTrue(
        all(seen_by_locus.values()),
        msg='Check that all 3 sample loci appeared in the examples.')

    # Both modes are compared against the same golden file.
    golden_file = _sharded(testdata.GOLDEN_ALLELE_FREQUENCY_EXAMPLES)
    golden_examples = list(tfrecord.read_tfrecords(golden_file))
    self.assertDeepVariantExamplesEqual(golden_examples, examples)
Example #7
0
 def test_call_end2end_with_empty_shards(self):
     """Checks the call_variants record count when input shards are empty."""
     # Read at most 10 examples from the golden calling set.
     golden_examples = tfrecord.read_tfrecords(
         testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
     examples = list(golden_examples)
     # At most 10 examples spread over 15 shards leaves several shards empty.
     num_shards = 15
     source_path = test_utils.test_tmpfile('sharded@{}'.format(num_shards))
     tfrecord.write_tfrecords(examples, source_path)
     expected_count = len(examples)
     self.assertCallVariantsEmitsNRecordsForRandomGuess(
         source_path, expected_count)
Example #8
0
def make_golden_dataset(compressed_inputs=False):
  """Returns a path to the golden postprocess input.

  Args:
    compressed_inputs: bool. If False, the golden input path is returned
      unchanged. If True, the golden CallVariantsOutput records are rewritten
      to a temporary '.tfrecord.gz' file (presumably gzip-compressed based on
      the suffix; confirm against tfrecord.write_tfrecords) and that path is
      returned.

  Returns:
    str. Path of the dataset to use as postprocess input.
  """
  if not compressed_inputs:
    return testdata.GOLDEN_POSTPROCESS_INPUT

  gz_path = test_utils.test_tmpfile(
      'golden.postprocess_single_site_input.tfrecord.gz')
  golden_records = tfrecord.read_tfrecords(
      testdata.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)
  tfrecord.write_tfrecords(golden_records, gz_path)
  return gz_path
Example #9
0
 def test_call_end2end_empty_first_shard(self):
   """Checks the call_variants record count when the first shard is empty."""
   # Read at most 10 examples from the golden calling set.
   golden_examples = tfrecord.read_tfrecords(
       testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_examples)
   # Shard 0 of 2 is written without any records.
   shard_zero = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
   tfrecord.write_tfrecords([], shard_zero)
   # Shard 1 of 2 receives every example.
   shard_one = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
   tfrecord.write_tfrecords(examples, shard_one)
   sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(
       sharded_spec, len(examples))
Example #10
0
  def test_read_tfrecords_max_records(self, filename, max_records):
    """Checks that read_tfrecords yields the expected count under max_records.

    Args:
      filename: name passed to self.write_test_protos to create the input.
      max_records: int or None. Forwarded to read_tfrecords; None means the
        full set of written protos is expected back.
    """
    protos, path = self.write_test_protos(filename)
    num_written = len(protos)

    # With no cap we expect everything; otherwise the cap or the total,
    # whichever is smaller.
    expected_n = (
        num_written if max_records is None else min(max_records, num_written))
    records = tfrecord.read_tfrecords(
        path, reference_pb2.ContigInfo, max_records=max_records)
    self.assertLen(list(records), expected_n)
Example #11
0
    def _call_end2end_helper(self, examples_path, model, shard_inputs):
        """Runs call_variants on the given examples and collects its outputs.

        Args:
          examples_path: path of the tf.Examples to read.
          model: model object handed to call_variants; its `name` decides how
            many batches are run.
          shard_inputs: if truthy, the examples are first rewritten to a
            3-way sharded temporary file that is used as input instead.

        Returns:
          Tuple (call_variants_outputs, examples, batch_size, max_batches).
        """
        examples = list(tfrecord.read_tfrecords(examples_path))

        source_path = examples_path
        if shard_inputs:
            # Feed call_variants a sharded copy of the examples instead.
            source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
            tfrecord.write_tfrecords(examples, source_path)

        # If we point the test at a headless server, it will often be 2x2,
        # which has 8 replicas.  Otherwise a smaller batch size is fine.
        batch_size = 8 if FLAGS.use_tpu else 4

        # The random guess model runs over everything; every other model is
        # limited to a single inference batch.
        max_batches = None if model.name == 'random_guess' else 1

        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        call_variants.call_variants(
            examples_filename=source_path,
            checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
            model=model,
            output_file=outfile,
            batch_size=batch_size,
            max_batches=max_batches,
            master='',
            use_tpu=FLAGS.use_tpu,
        )

        output_records = tfrecord.read_tfrecords(
            outfile, deepvariant_pb2.CallVariantsOutput)
        call_variants_outputs = list(output_records)

        return call_variants_outputs, examples, batch_size, max_batches
Example #12
0
 def test_call_variants_with_no_shape(self, model):
   """Checks that an example lacking 'image/shape' raises ValueError."""
   # Take the first record of a valid examples file.
   golden_examples = tfrecord.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
   example = next(golden_examples)
   # Drop the image/shape feature to make the record invalid.
   del example.features.feature['image/shape']
   source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   tfrecord.write_tfrecords([example], source_path)
   expected_error = ('Invalid image/shape: we expect to find an image/shape '
                     'field with length 3.')
   with six.assertRaisesRegex(self, ValueError, expected_error):
     ds = call_variants.prepare_inputs(source_path)
     _ = list(_get_infer_batches(ds, model=model, batch_size=1))
Example #13
0
  def test_reading_sharded_input_with_empty_shards_does_not_crash(self):
    """Runs postprocess_variants.main on sharded input with an empty shard."""
    golden_records = tfrecord.read_tfrecords(
        testdata.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    empty_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00000-of-00002')
    populated_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00001-of-00002')
    # Shard 0 gets no records; shard 1 gets all of the golden records.
    tfrecord.write_tfrecords([], empty_shard)
    tfrecord.write_tfrecords(golden_records, populated_shard)

    FLAGS.infile = test_utils.test_tmpfile('reading_empty_shard.tfrecord@2')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('calls_reading_empty_shard.vcf')

    # There are no assertions: the test passes if main() does not raise.
    postprocess_variants.main(['postprocess_variants.py'])
Example #14
0
    def test_make_read_writer_tfrecords(self):
        """Writes two reads through SamWriter and reads them back as TFRecords."""
        outfile = test_utils.test_tmpfile('test.tfrecord')
        writer = sam.SamWriter(outfile, header=self.header)

        # The writer must work as a context manager and accept reads while
        # it is open.
        with writer:
            for read in (self.read1, self.read2):
                writer.write(read)

        # The output should contain both reads, in the order written.
        written_reads = list(
            tfrecord.read_tfrecords(outfile, proto=reads_pb2.Read))
        self.assertEqual([self.read1, self.read2], written_reads)
Example #15
0
  def test_read_write_tfrecords(self, filename):
    """Round-trips protos through write_test_protos and read_tfrecords.

    Args:
      filename: name passed to self.write_test_protos; a name containing '@'
        is treated as a sharded spec.
    """
    protos, path = self.write_test_protos(filename)

    # read_tfrecords is expected to return a real generator object.
    reader = tfrecord.read_tfrecords(path, reference_pb2.ContigInfo)
    self.assertEqual(type(reader), types.GeneratorType)

    # Compare the round-tripped contents with what was written.
    is_sharded = '@' in filename
    if not is_sharded:
      self.assertEqual(protos, list(reader))
    else:
      # Sharded outputs are striped across shards, so only the multiset of
      # records is compared, not their order.
      self.assertCountEqual(protos, reader)
Example #16
0
def _transform_call_variants_output_to_variants(input_sorted_tfrecord_path,
                                                qual_filter,
                                                multi_allelic_qual_filter,
                                                sample_name, group_variants,
                                                use_multiallelic_model):
  """Yields Variant protos built from grouped CallVariantsOutput protos.

  Records are read from the input TFRecord and grouped with
  itertools.groupby. Each group's predictions are merged (using
  `multi_allelic_qual_filter` and the multiallelic model) and a call is added
  to the resulting canonical variant (using `qual_filter` and `sample_name`).

  Args:
    input_sorted_tfrecord_path: str. TFRecord format file containing sorted
      CallVariantsOutput protos.
    qual_filter: double. The qual value below which to filter variants.
    multi_allelic_qual_filter: double. The qual value below which to filter
      multi-allelic variants.
    sample_name: str. Sample name to write to VCF file.
    group_variants: bool. If true, group variants that have same start and end
      position. Otherwise groupby receives no key function and groups
      consecutive equal records.
    use_multiallelic_model: if True, use a specialized model for genotype
      resolution of multiallelic cases with two alts.

  Yields:
    Variant protos in sorted order representing the CallVariantsOutput calls.
  """
  multiallelic_model = get_multiallelic_model(
      use_multiallelic_model=use_multiallelic_model)

  # A key of None makes groupby compare the records themselves.
  group_fn = (
      (lambda x: variant_utils.variant_range(x.variant))
      if group_variants else None)

  call_outputs = tfrecord.read_tfrecords(
      input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)
  for _, grouped_outputs in itertools.groupby(call_outputs, group_fn):
    ordered_outputs = _sort_grouped_variants(grouped_outputs)
    canonical_variant, predictions = merge_predictions(
        ordered_outputs,
        multi_allelic_qual_filter,
        multiallelic_model=multiallelic_model)
    yield add_call_to_variant(
        canonical_variant,
        predictions,
        qual_filter=qual_filter,
        sample_name=sample_name)
Example #17
0
  def assertTfDataSetExamplesMatchExpected(self,
                                           input_fn,
                                           expected_dataset,
                                           use_tpu=False,
                                           workaround_list_files=False):
    """Asserts that loci produced by input_fn match those stored on disk.

    The dataset returned by `input_fn` is drained one example at a time
    (batch_size 1) inside a TF1 session, collecting each example's locus. The
    collected loci are compared with the 'locus' feature of every record read
    from `expected_dataset.input_file_spec`.

    Args:
      input_fn: callable taking a params dict and returning the dataset to
        iterate.
      expected_dataset: object providing `input_file_spec`, `num_examples` and
        `tensor_shape` for the expected data.
      use_tpu: if True, each locus is converted with
        tf_utils.int_tensor_to_string before being recorded.
      workaround_list_files: if True, the seen loci are sorted before the
        comparison.
    """
    # Note that we use input_fn to get an iterator, while we use
    # expected_dataset to get a filename, even though they are the same
    # type (DeepVariantInput), and may even be the same object.
    with tf.compat.v1.Session() as sess:
      params = {'batch_size': 1}
      batch_feed = tf.compat.v1.data.make_one_shot_iterator(
          input_fn(params)).get_next()

      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())
      seen = []
      # Pull batches until the iterator signals exhaustion.
      while True:
        try:
          features, _ = sess.run(batch_feed)
        except tf.errors.OutOfRangeError:
          break
        # batch_size is 1, so index 0 is the only element of the batch.
        locus = features['locus'][0]
        if use_tpu:
          locus = tf_utils.int_tensor_to_string(locus)
        # NB, this looks like: array(['chr20:10001019-10001019'], dtype=object)
        seen.append(locus)

    if workaround_list_files:
      # This really only works for loci, because those are string valued and
      # are expected to show up in sorted order.  For arbitrary data that's
      # not true.  In prod we have the version of tf that lets us turn off
      # shuffling so this path is skipped, but kokoro hits this.
      seen = sorted(seen)

    # Expected loci come straight from the records on disk.
    expected_loci = [
        example.features.feature['locus'].bytes_list.value[0]
        for example in tfrecord.read_tfrecords(expected_dataset.input_file_spec)
    ]
    self.assertLen(expected_loci, expected_dataset.num_examples)
    # Print both lists on mismatch to make the failure below easier to debug.
    if seen != expected_loci:
      print('\n\nlen expected seen', len(expected_loci), len(seen))
      print('\n\nexpected=', expected_loci)
      print('\n\nseen=', seen)
    self.assertEqual(expected_loci, seen)
    # Note that this expected shape comes from the golden dataset. If the data
    # is remade in the future, the values might need to be modified accordingly.
    self.assertEqual(dv_constants.PILEUP_DEFAULT_DIMS,
                     expected_dataset.tensor_shape)
Example #18
0
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
  """Builds an input fn over the golden training examples.

  Args:
    compressed_inputs: bool. If True, the golden training examples are first
      rewritten to a temporary '.tfrecord.gz' file which is used as the input
      spec; otherwise the golden file is used directly.
    mode: mode forwarded to data_providers.get_input_fn_from_filespec.
    use_tpu: bool forwarded to data_providers.get_input_fn_from_filespec.

  Returns:
    The result of data_providers.get_input_fn_from_filespec for the chosen
    input spec, named 'labeled_golden'.
  """
  source_path = testdata.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    source_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    golden_examples = tfrecord.read_tfrecords(
        testdata.GOLDEN_TRAINING_EXAMPLES)
    tfrecord.write_tfrecords(golden_examples, source_path)

  return data_providers.get_input_fn_from_filespec(
      input_file_spec=source_path,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
      name='labeled_golden',
      mode=mode,
      tensor_shape=None,
      use_tpu=use_tpu)
    def test_writing_canned_records(self):
        """Writes every canned GFF record and compares the resulting file."""
        # The input records are stored in TFRecord format.
        canned_path = test_utils.genomics_core_testdata(
            'test_features.gff.tfrecord')
        writer_options = gff_pb2.GffWriterOptions()
        canned_records = tfrecord.read_tfrecords(
            canned_path, proto=gff_pb2.GffRecord)
        gff_records = list(canned_records)
        out_fname = test_utils.test_tmpfile('output.gff')

        gff_out = gff_writer.GffWriter.to_file(out_fname, self.header,
                                               writer_options)
        with gff_out as writer:
            for gff_record in gff_records:
                writer.write(gff_record)

        # The written lines must match the expected GFF content.
        with open(out_fname) as f:
            written_lines = f.readlines()
            self.assertEqual(written_lines, self.expected_gff_content)
Example #20
0
  def assertCallVariantsEmitsNRecordsForInceptionV3(self, filename,
                                                    num_examples):
    """Asserts call_variants with 'inception_v3' writes num_examples records.

    Args:
      filename: examples file (or spec) handed to call_variants.
      num_examples: expected number of CallVariantsOutput protos written.
    """
    outfile = test_utils.test_tmpfile('inception_v3.call_variants.tfrecord')
    inception_model = modeling.get_model('inception_v3')

    call_variants.call_variants(
        examples_filename=filename,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=inception_model,
        output_file=outfile,
        batch_size=4,
        max_batches=None)

    output_records = tfrecord.read_tfrecords(
        outfile, deepvariant_pb2.CallVariantsOutput)
    call_variants_outputs = list(output_records)
    # The number of output protos must equal the expected example count.
    self.assertEqual(len(call_variants_outputs), num_examples)
Example #21
0
 def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                   num_examples):
     """Asserts call_variants with 'random_guess' writes num_examples records.

     Args:
       filename: examples file (or spec) handed to call_variants.
       num_examples: expected number of CallVariantsOutput protos written.
     """
     outfile = test_utils.test_tmpfile('call_variants.tfrecord')
     random_guess_model = modeling.get_model('random_guess')
     call_variants.call_variants(
         examples_filename=filename,
         checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
         model=random_guess_model,
         output_file=outfile,
         batch_size=4,
         max_batches=None,
         master='',
         use_tpu=FLAGS.use_tpu)
     output_records = tfrecord.read_tfrecords(
         outfile, deepvariant_pb2.CallVariantsOutput)
     call_variants_outputs = list(output_records)
     # The number of output protos must equal the expected example count.
     self.assertEqual(len(call_variants_outputs), num_examples)
    def test_writing_canned_records(self):
        """Writes every canned FASTQ record and compares the resulting file."""
        # The input records are stored in TFRecord format.
        canned_path = test_utils.genomics_core_testdata(
            'test_reads.fastq.tfrecord')

        writer_options = fastq_pb2.FastqWriterOptions()
        canned_records = tfrecord.read_tfrecords(
            canned_path, proto=fastq_pb2.FastqRecord)
        fastq_records = list(canned_records)
        out_fname = test_utils.test_tmpfile('output.fastq')

        fastq_out = fastq_writer.FastqWriter.to_file(out_fname, writer_options)
        with fastq_out as writer:
            for fastq_record in fastq_records:
                writer.write(fastq_record)

        # The written lines must match the expected FASTQ content.
        with gfile.Open(out_fname, 'r') as f:
            written_lines = f.readlines()
            self.assertEqual(written_lines, self.expected_fastq_content)
    def test_writing_canned_records(self):
        """Writes every canned BED record and compares the resulting file."""
        # The input records are stored in TFRecord format.
        canned_path = test_utils.genomics_core_testdata(
            'test_regions.bed.tfrecord')

        header = bed_pb2.BedHeader(num_fields=12)
        writer_options = bed_pb2.BedWriterOptions()
        canned_records = tfrecord.read_tfrecords(
            canned_path, proto=bed_pb2.BedRecord)
        bed_records = list(canned_records)
        out_fname = test_utils.test_tmpfile('output.bed')

        bed_out = bed_writer.BedWriter.to_file(out_fname, header,
                                               writer_options)
        with bed_out as writer:
            for bed_record in bed_records:
                writer.write(bed_record)

        # The written lines must match the expected BED content.
        with gfile.Open(out_fname, 'r') as f:
            written_lines = f.readlines()
            self.assertEqual(written_lines, self.expected_bed_content)
Example #24
0
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Checks that call_variants raises ValueError for a bad image/format.

    Args:
      model: model object handed to call_variants.
      bad_format: value written into the example's 'image/format' feature.
    """
    # Take the first record of a valid examples file.
    golden_examples = tfrecord.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
    example = next(golden_examples)
    # Replace image/format with an invalid value (anything but 'raw').
    format_feature = example.features.feature['image/format']
    format_feature.bytes_list.value[0] = bad_format
    source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    tfrecord.write_tfrecords([example], source_path)
    outfile = test_utils.test_tmpfile('call_variants_invalid_format.tfrecord')

    call_kwargs = dict(
        examples_filename=source_path,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=model,
        output_file=outfile,
        batch_size=1,
        max_batches=1,
        use_tpu=FLAGS.use_tpu)
    with self.assertRaises(ValueError):
      call_variants.call_variants(**call_kwargs)
Example #25
0
  def test_make_examples_with_variant_selection(self,
                                                select_types,
                                                expected_count,
                                                keep_legacy_behavior=False):
    """Runs make_examples in calling mode and counts the emitted candidates.

    Args:
      select_types: value for FLAGS.select_variant_types, or None to leave
        that flag untouched.
      expected_count: expected number of records in the candidates output.
      keep_legacy_behavior: value for
        FLAGS.keep_legacy_allele_counter_behavior.
    """
    if select_types is not None:
      FLAGS.select_variant_types = select_types

    region = ranges.parse_literal('chr20:10,000,000-10,010,000')
    candidates_path = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
    examples_path = test_utils.test_tmpfile(_sharded('examples.tfrecord'))

    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.CHR20_BAM
    FLAGS.candidates = candidates_path
    FLAGS.examples = examples_path
    FLAGS.partition_size = 1000
    FLAGS.mode = 'calling'
    FLAGS.keep_legacy_allele_counter_behavior = keep_legacy_behavior

    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)

    # Only the number of candidate records is checked here.
    candidate_records = tfrecord.read_tfrecords(FLAGS.candidates)
    self.assertLen(list(candidate_records), expected_count)
Example #26
0
  def test_reading_sharded_dataset(self, compressed_inputs, use_tpu):
    """Checks that a sharded copy of the golden dataset reads back correctly.

    Args:
      compressed_inputs: forwarded to make_golden_dataset.
      use_tpu: forwarded to make_golden_dataset.
    """
    golden_dataset = make_golden_dataset(compressed_inputs, use_tpu=use_tpu)

    # Re-write the golden examples across three shards.
    shard_count = 3
    sharded_path = test_utils.test_tmpfile('sharded@{}'.format(shard_count))
    golden_examples = tfrecord.read_tfrecords(golden_dataset.input_file_spec)
    tfrecord.write_tfrecords(golden_examples, sharded_path)

    config_file = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=sharded_path,
        num_examples=golden_dataset.num_examples)
    sharded_input_fn = data_providers.get_input_fn_from_dataset(
        config_file, mode=tf.estimator.ModeKeys.EVAL)

    # workaround_list_files is needed because wildcards, and so sharded
    # files, are nondeterministically ordered (for now).
    self.assertTfDataSetExamplesMatchExpected(
        sharded_input_fn,
        golden_dataset,
        workaround_list_files=True,
    )
Example #27
0
  def verify_examples(self, examples_filename, region, options, verify_labels):
    """Runs structural checks on the tf.Examples in a file and returns them.

    Args:
      examples_filename: path (or spec) of the tf.Examples to read.
      region: forwarded to self.verify_variants.
      options: forwarded to self.verify_variants.
      verify_labels: if truthy, each example must also carry a 'label'
        feature.

    Returns:
      The list of examples read from examples_filename.
    """
    required_features = [
        'variant/encoded', 'locus', 'image/format', 'image/encoded',
        'alt_allele_indices/encoded'
    ]
    if verify_labels:
      required_features.append('label')

    examples = list(tfrecord.read_tfrecords(examples_filename))
    for example in examples:
      feature_map = example.features.feature
      for feature_name in required_features:
        self.assertIn(feature_name, feature_map)
      # pylint: disable=g-explicit-length-test
      self.assertNotEmpty(tf_utils.example_alt_alleles_indices(example))

    # The variants embedded in the examples get their own validation pass.
    variants = list(map(tf_utils.example_variant, examples))
    self.verify_variants(variants, region, options, is_gvcf=False)

    return examples
Example #28
0
 def test_make_examples_training_end2end_with_customized_classes_labeler(self):
   """Runs make_examples in training mode with the customized classes labeler."""
   region = ranges.parse_literal('chr20:10,000,000-10,004,000')
   candidates_path = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
   examples_path = test_utils.test_tmpfile(_sharded('examples.tfrecord'))

   # Labeler configuration.
   FLAGS.labeler_algorithm = 'customized_classes_labeler'
   FLAGS.customized_classes_labeler_classes_list = 'ref,class1,class2'
   FLAGS.customized_classes_labeler_info_field_name = 'type'
   # Inputs, outputs and run mode.
   FLAGS.regions = [ranges.to_literal(region)]
   FLAGS.ref = testdata.CHR20_FASTA
   FLAGS.reads = testdata.CHR20_BAM
   FLAGS.candidates = candidates_path
   FLAGS.examples = examples_path
   FLAGS.partition_size = 1000
   FLAGS.mode = 'training'
   FLAGS.gvcf_gq_binsize = 5
   FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF_WITH_TYPES
   FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

   options = make_examples.default_options(add_flags=True)
   make_examples_core.make_examples_runner(options)
   golden_file = _sharded(testdata.CUSTOMIZED_CLASSES_GOLDEN_TRAINING_EXAMPLES)

   # Structural checks on the produced examples, labels included.
   examples = self.verify_examples(
       FLAGS.examples, region, options, verify_labels=True)
   golden_examples = list(tfrecord.read_tfrecords(golden_file))
   self.assertDeepVariantExamplesEqual(examples, golden_examples)
Example #29
0
def get_one_example_from_examples_path(source, proto=None):
    """Get the first record from `source`.

    The files matching `source` are tried in the order they are listed; the
    first record of the first non-empty file is returned.

    Args:
      source: str. A pattern or a comma-separated list of patterns that
        represent file names.
      proto: A proto class. proto.FromString() will be called on each
        serialized record in path to parse it.

    Returns:
      The first record, or None if every matching file is empty.

    Raises:
      ValueError: if no file matches `source`.
    """
    files = sharded_file_utils.glob_list_sharded_file_patterns(source)
    if not files:
        raise ValueError(
            'Cannot find matching files with the pattern "{}"'.format(source))
    for f in files:
        try:
            # Use the next() builtin: the Python 2-only `.next()` method does
            # not exist on Python 3 iterators and would raise AttributeError.
            return next(tfrecord.read_tfrecords(f, proto=proto))
        except StopIteration:
            # StopIteration on the very first next() means this file is empty.
            # Move on to the next file to try to get one example.
            continue
    return None
  def test_writing_canned_variants(self):
    """Writes the first five canned variants to a VCF and checks the text.

    The variants are read from a golden TFRecord file, written through
    VcfWriter with a hand-built header, and the resulting file is compared
    line by line against the expected VCF content.
    """
    # The golden input is stored in the TF record format.
    golden_tfrecord = test_utils.genomics_core_testdata(
        'test_samples.vcf.golden.tfrecord')

    options = variants_pb2.VcfWriterOptions()

    # (name, n_bases) for each contig declared in the header.
    contig_sizes = [
        ('chr1', 248956422),
        ('chr2', 242193529),
        ('chr3', 198295559),
        ('chrX', 156040895),
    ]
    contigs = [
        reference_pb2.ContigInfo(name=contig_name, n_bases=contig_len)
        for contig_name, contig_len in contig_sizes
    ]

    # Filters that are declared with an id only (no description set).
    tranche_filter_ids = [
        'VQSRTrancheINDEL95.00to96.00',
        'VQSRTrancheINDEL96.00to97.00',
        'VQSRTrancheINDEL97.00to99.00',
        'VQSRTrancheINDEL99.00to99.50',
        'VQSRTrancheINDEL99.50to99.90',
        'VQSRTrancheINDEL99.90to99.95',
        'VQSRTrancheINDEL99.95to100.00+',
        'VQSRTrancheINDEL99.95to100.00',
        'VQSRTrancheSNP99.50to99.60',
        'VQSRTrancheSNP99.60to99.80',
        'VQSRTrancheSNP99.80to99.90',
        'VQSRTrancheSNP99.90to99.95',
        'VQSRTrancheSNP99.95to100.00+',
        'VQSRTrancheSNP99.95to100.00',
    ]
    filters = [
        variants_pb2.VcfFilterInfo(
            id='PASS', description='All filters passed'),
        variants_pb2.VcfFilterInfo(id='LowQual', description=''),
    ]
    filters.extend(
        variants_pb2.VcfFilterInfo(id=filter_id)
        for filter_id in tranche_filter_ids)

    # (id, number, type, description) for each FORMAT header line.
    format_specs = [
        ('GT', '1', 'String', 'Genotype'),
        ('GQ', '1', 'Integer', 'Genotype Quality'),
        ('DP', '1', 'Integer', 'Read depth of all passing filters reads.'),
        ('MIN_DP', '1', 'Integer',
         'Minimum DP observed within the GVCF block.'),
        ('AD', 'R', 'Integer',
         'Read depth of all passing filters reads for each allele.'),
        ('VAF', 'A', 'Float', 'Variant allele fractions.'),
        ('PL', 'G', 'Integer', 'Genotype likelihoods, Phred encoded'),
    ]
    formats = [
        variants_pb2.VcfFormatInfo(
            id=fmt_id, number=fmt_number, type=fmt_type, description=fmt_desc)
        for fmt_id, fmt_number, fmt_type, fmt_desc in format_specs
    ]

    header = variants_pb2.VcfHeader(
        contigs=contigs,
        sample_names=['NA12878_18_99'],
        filters=filters,
        infos=[
            variants_pb2.VcfInfo(
                id='END',
                number='1',
                type='Integer',
                description='Stop position of the interval')
        ],
        formats=formats,
    )

    canned_variants = list(
        tfrecord.read_tfrecords(golden_tfrecord, proto=variants_pb2.Variant))
    vcf_path = test_utils.test_tmpfile('output.vcf')
    # Only the first five canned variants are written out.
    with vcf_writer.VcfWriter.to_file(vcf_path, header, options) as writer:
      for variant in canned_variants[:5]:
        writer.write(variant)

    # Check: are the variants written as expected?
    # pylint: disable=line-too-long
    expected_vcf_content = [
        '##fileformat=VCFv4.2\n',
        '##FILTER=<ID=PASS,Description="All filters passed">\n',
        '##FILTER=<ID=LowQual,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
        '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
        'the interval">\n',
        '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
        '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
        '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
        'passing filters reads.">\n',
        '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP '
        'observed within the GVCF block.">\n',
        '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
        'passing filters reads for each allele.">\n',
        '##FORMAT=<ID=VAF,Number=A,Type=Float,Description=\"Variant allele '
        'fractions.">\n',
        '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
        'likelihoods, Phred encoded">\n',
        '##contig=<ID=chr1,length=248956422>\n',
        '##contig=<ID=chr2,length=242193529>\n',
        '##contig=<ID=chr3,length=198295559>\n',
        '##contig=<ID=chrX,length=156040895>\n',
        '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
        'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
        'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
        'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
        'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
        'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
    ]
    # pylint: enable=line-too-long

    # Read the written file back and compare it with the expected lines.
    with gfile.Open(vcf_path, 'r') as vcf_file:
      written_lines = vcf_file.readlines()
    self.assertEqual(written_lines, expected_vcf_content)