Beispiel #1
0
 def setUp(self):
     """Loads the canned GFF records and builds the header used by the tests."""
     records_path = test_utils.genomics_core_testdata(
         'test_features.gff.tfrecord')
     record_iter = io_utils.read_tfrecords(
         records_path, proto=gff_pb2.GffRecord)
     # Materialize the record stream into a list.
     self.records = list(record_iter)
     region = ranges.make_range('ctg123', 0, 1497228)
     self.header = gff_pb2.GffHeader(sequence_regions=[region])
Beispiel #2
0
def _transform_call_variants_output_to_variants(input_sorted_tfrecord_path,
                                                qual_filter,
                                                multi_allelic_qual_filter,
                                                sample_name):
    """Yields Variant protos in sorted order from CallVariantsOutput protos.

    CallVariantsOutput records are read from the input TFRecord, grouped by
    consecutive runs that share the same variant range, and each group is
    merged into a single Variant. Two filters are applied: 1) variants whose
    quality is below `qual_filter` are filtered, and 2) for multi-allelic
    variants, individual alleles whose quality is below
    `multi_allelic_qual_filter` are omitted.

    Args:
      input_sorted_tfrecord_path: str. TFRecord format file containing sorted
        CallVariantsOutput protos.
      qual_filter: double. The qual value below which to filter variants.
      multi_allelic_qual_filter: double. The qual value below which to filter
        multi-allelic variants.
      sample_name: str. Sample name to write to VCF file.

    Yields:
      Variant protos in sorted order representing the CallVariantsOutput calls.
    """

    def _range_of(call_variants_output):
        # Grouping key: the range of the variant carried by the record.
        return variant_utils.variant_range(call_variants_output.variant)

    sorted_outputs = io_utils.read_tfrecords(
        input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)
    # groupby only merges adjacent records, so the input must already be
    # sorted for all records of one site to land in the same group.
    for _, same_site in itertools.groupby(sorted_outputs, _range_of):
        ordered = _sort_grouped_variants(same_site)
        canonical_variant, predictions = merge_predictions(
            ordered, multi_allelic_qual_filter)
        yield add_call_to_variant(canonical_variant,
                                  predictions,
                                  qual_filter=qual_filter,
                                  sample_name=sample_name)
def examples_to_variants(examples_path, max_records=None):
  """Yields one Variant proto per distinct variant range in examples_path.

  The tf.Examples at examples_path (which may be a sharded spec) are read, the
  variant of each example is extracted, and the variants are sorted by range.
  When several examples share the same range (e.g. different alt_alleles), only
  the first one in sorted order is yielded as the representative.

  Args:
    examples_path: str. Path, or sharded spec, to labeled tf.Examples produced
      by DeepVariant in training mode.
    max_records: int or None. Maximum number of records to read, or None, to
      read all of the records.

  Yields:
    nucleus.protos.Variant protos in coordinate-sorted order.

  Raises:
    ValueError: if we find a Variant in any example that doesn't have genotypes.
  """
  range_key = variant_utils.variant_range_tuple
  all_variants = [
      tf_utils.example_variant(example)
      for example in io_utils.read_tfrecords(
          examples_path, max_records=max_records)
  ]
  all_variants.sort(key=range_key)

  for _, same_range in itertools.groupby(all_variants, range_key):
    # Only the first variant of each group is inspected and emitted.
    representative = next(same_range)
    only_call = variant_utils.only_call(representative)
    if not variantcall_utils.has_genotypes(only_call):
      message = (
          'Variant {} does not have any genotypes. This tool only works with '
          'variants that have been labeled.').format(
              variant_utils.variant_key(representative))
      raise ValueError(message)
    yield representative
def _transform_call_variants_output_to_variants(
    input_sorted_tfrecord_path, qual_filter, multi_allelic_qual_filter,
    sample_name):
  """Yields Variant protos in sorted order from CallVariantsOutput protos.

  Adjacent CallVariantsOutput records that share a variant range are merged
  into one Variant. Two filters are applied along the way: 1) variants whose
  quality is lower than `qual_filter` are filtered, and 2) multi-allelic
  variants omit individual alleles whose qualities are lower than
  `multi_allelic_qual_filter`.

  Args:
    input_sorted_tfrecord_path: str. TFRecord format file containing sorted
      CallVariantsOutput protos.
    qual_filter: double. The qual value below which to filter variants.
    multi_allelic_qual_filter: double. The qual value below which to filter
      multi-allelic variants.
    sample_name: str. Sample name to write to VCF file.

  Yields:
    Variant protos in sorted order representing the CallVariantsOutput calls.
  """

  def _site_key(call_variants_output):
    # Records are grouped by the range of the variant they describe.
    return variant_utils.variant_range(call_variants_output.variant)

  sorted_outputs = io_utils.read_tfrecords(
      input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)
  for _, same_site_outputs in itertools.groupby(sorted_outputs, _site_key):
    # The group iterator is materialized before being handed to the merge.
    canonical_variant, predictions = merge_predictions(
        list(same_site_outputs), multi_allelic_qual_filter)
    yield add_call_to_variant(
        canonical_variant,
        predictions,
        qual_filter=qual_filter,
        sample_name=sample_name)
Beispiel #5
0
    def assertDataSetExamplesMatchExpected(self, dataset, expected_dataset):
        """Asserts that loci read via the slim provider match the raw records.

        The loci produced by a slim DatasetDataProvider are compared, in order,
        with the 'locus' feature read directly from `expected_dataset.source`.

        NOTE(review): `dataset` is never read in this method; the provider is
        built from `expected_dataset.get_slim_dataset()`. Confirm whether the
        provider was meant to be built from `dataset` instead.

        Args:
          dataset: unused (see note above).
          expected_dataset: dataset object providing `get_slim_dataset()`,
            `source`, `num_examples` and `tensor_shape`.
        """
        n_expected = expected_dataset.num_examples
        with tf.Session() as sess:
            reader_kwargs = {
                'options':
                io_utils.make_tfrecord_options(expected_dataset.source)
            }
            provider = slim.dataset_data_provider.DatasetDataProvider(
                expected_dataset.get_slim_dataset(),
                shuffle=False,
                reader_kwargs=reader_kwargs)
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            image, label, locus = provider.get(['image', 'label', 'locus'])
            seen = []
            for _ in range(n_expected):
                # All three tensors are fetched; only the locus is recorded.
                _, _, locus_value = sess.run([image, label, locus])
                seen.append(locus_value)
            coord.request_stop()
            coord.join(threads)

        expected_loci = []
        for example in io_utils.read_tfrecords(expected_dataset.source):
            locus_feature = example.features.feature['locus']
            expected_loci.append(locus_feature.bytes_list.value[0])

        self.assertEqual(len(expected_loci), n_expected)
        self.assertEqual(expected_loci, seen)
        # Note that this expected shape comes from the golden dataset. If the
        # data is remade in the future, the values might need to be modified
        # accordingly.
        golden_shape = [100, 221, pileup_image.DEFAULT_NUM_CHANNEL]
        self.assertEqual(golden_shape, expected_dataset.tensor_shape)
  def assertDataSetExamplesMatchExpected(self, dataset, expected_dataset):
    """Asserts that loci from the slim provider equal those in the source file.

    Reads `expected_dataset.num_examples` examples through a slim
    DatasetDataProvider (unshuffled) and checks that their loci equal, in
    order, the 'locus' feature of the records in `expected_dataset.source`.

    NOTE(review): the `dataset` argument is not used anywhere in the body; the
    provider is created from `expected_dataset`. Verify this is intended.

    Args:
      dataset: unused (see note above).
      expected_dataset: dataset object providing `get_slim_dataset()`, `source`,
        `num_examples` and `tensor_shape`.
    """
    example_count = expected_dataset.num_examples
    with tf.Session() as sess:
      tfrecord_options = io_utils.make_tfrecord_options(
          expected_dataset.source)
      provider = slim.dataset_data_provider.DatasetDataProvider(
          expected_dataset.get_slim_dataset(),
          shuffle=False,
          reader_kwargs={'options': tfrecord_options})
      sess.run(tf.global_variables_initializer())
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord, sess=sess)
      image, label, locus = provider.get(['image', 'label', 'locus'])
      seen = []
      for _ in range(example_count):
        fetched = sess.run([image, label, locus])
        seen.append(fetched[2])  # Index 2 is the locus.
      coord.request_stop()
      coord.join(threads)

    raw_examples = io_utils.read_tfrecords(expected_dataset.source)
    expected_loci = [
        raw.features.feature['locus'].bytes_list.value[0]
        for raw in raw_examples
    ]
    self.assertEqual(len(expected_loci), example_count)
    self.assertEqual(expected_loci, seen)
    # Note that this expected shape comes from the golden dataset. If the data
    # is remade in the future, the values might need to be modified accordingly.
    golden_shape = [100, 221, pileup_image.DEFAULT_NUM_CHANNEL]
    self.assertEqual(golden_shape, expected_dataset.tensor_shape)
Beispiel #7
0
def _get_one_example_from_examples_path(source):
    """Reads one record from source.

    Args:
      source: str. One path/pattern, or a comma-separated list of them. Each
        entry is normalized to a sharded file pattern and globbed.

    Returns:
      The first record that can be read from any matching file, or None if
      every matching file is empty.

    Raises:
      ValueError: if any entry of `source` matches no files.
    """
    # redacted
    # io_utils.read_tfrecord can read wildcard file patterns.
    # The source can be a comma-separated list.
    source_paths = source.split(',')
    for source_path in source_paths:
        files = tf.gfile.Glob(
            io_utils.NormalizeToShardedFilePattern(source_path))
        if not files:
            if len(source_paths) > 1:
                raise ValueError(
                    'Cannot find matching files with the pattern "{}" in "{}"'.
                    format(source_path, ','.join(source_paths)))
            else:
                raise ValueError(
                    'Cannot find matching files with the pattern "{}"'.format(
                        source_path))
        for f in files:
            try:
                # Use the next() builtin rather than the generator's .next()
                # method: the method only exists on Python 2, while the
                # builtin works on both Python 2 and Python 3.
                return next(io_utils.read_tfrecords(f))
            except StopIteration:
                # Getting a StopIteration from one next() means this file is
                # empty. Move on to the next one to try to get one example.
                pass
    return None
def examples_to_variants(examples_path, max_records=None):
    """Yields Variant protos from the examples in examples_path.

    Reads the tf.Examples at examples_path (which may be a sharded spec),
    extracts each example's variant, sorts the variants by range, and yields
    the first variant of every run that shares the same range. This picks a
    single representative when several examples describe different
    alt_alleles of the same site.

    Args:
      examples_path: str. Path, or sharded spec, to labeled tf.Examples
        produced by DeepVariant in training mode.
      max_records: int or None. Maximum number of records to read, or None,
        to read all of the records.

    Yields:
      nucleus.protos.Variant protos in coordinate-sorted order.

    Raises:
      ValueError: if we find a Variant in any example that doesn't have
        genotypes.
    """
    by_range = variant_utils.variant_range_tuple
    extracted = [
        tf_utils.example_variant(example)
        for example in io_utils.read_tfrecords(examples_path,
                                               max_records=max_records)
    ]
    extracted.sort(key=by_range)

    for _, duplicates in itertools.groupby(extracted, by_range):
        # The first element of the group is the representative that is
        # validated and yielded; the rest of the group is ignored.
        chosen = next(duplicates)
        call = variant_utils.only_call(chosen)
        if not variantcall_utils.has_genotypes(call):
            error_text = (
                'Variant {} does not have any genotypes. This tool only works with '
                'variants that have been labeled.').format(
                    variant_utils.variant_key(chosen))
            raise ValueError(error_text)
        yield chosen
Beispiel #9
0
  def test_realigner_diagnostics(self, enabled, emit_reads):
    """Checks which diagnostic outputs the realigner writes for a region.

    Args:
      enabled: bool. Value assigned to `config.diagnostics.enabled`.
      emit_reads: bool. Value assigned to
        `config.diagnostics.emit_realigned_reads`.
    """
    # Configure a realigner whose diagnostics write under a fresh temp dir,
    # realign the reads of one region, then inspect what was written.
    dx_dir = test_utils.test_tmpfile('dx')
    region_str = 'chr20:10046178-10046188'
    region = ranges.parse_literal(region_str)
    # Directory name / 'window' value expected in the diagnostics output; it
    # differs from region_str.
    assembled_region_str = 'chr20:10046099-10046267'
    reads = _get_reads(region)
    self.config = realigner.realigner_config(FLAGS)
    self.config.diagnostics.enabled = enabled
    self.config.diagnostics.output_root = dx_dir
    self.config.diagnostics.emit_realigned_reads = emit_reads
    self.reads_realigner = realigner.Realigner(self.config, self.ref_reader)
    _, realigned_reads = self.reads_realigner.realign_reads(reads, region)
    self.reads_realigner.diagnostic_logger.close()  # Force close all resources.

    if not enabled:
      # Make sure our diagnostic output isn't emitted.
      self.assertFalse(tf.gfile.Exists(dx_dir))
    else:
      # Our root directory exists.
      self.assertTrue(tf.gfile.IsDirectory(dx_dir))

      # We expect a realigner_metrics.csv in our rootdir with 1 entry in it.
      metrics_file = os.path.join(
          dx_dir, self.reads_realigner.diagnostic_logger.metrics_filename)
      self.assertTrue(tf.gfile.Exists(metrics_file))
      with tf.gfile.FastGFile(metrics_file) as fin:
        rows = list(csv.DictReader(fin))
        self.assertEqual(len(rows), 1)
        self.assertEqual(
            set(rows[0].keys()), {'window', 'k', 'n_haplotypes', 'time'})
        self.assertEqual(rows[0]['window'], assembled_region_str)
        self.assertEqual(int(rows[0]['k']), 25)
        # NOTE(review): assertTrue treats its second argument as the failure
        # message, so this only checks that n_haplotypes is non-zero. It looks
        # like assertEqual(..., 2) was intended -- confirm the expected value.
        self.assertTrue(int(rows[0]['n_haplotypes']), 2)
        # Check that our runtime is reasonable (greater than 0, less than 10 s).
        self.assertTrue(0.0 < float(rows[0]['time']) < 10.0)

      # As does the subdirectory for this region.
      region_subdir = os.path.join(dx_dir, assembled_region_str)
      self.assertTrue(tf.gfile.IsDirectory(region_subdir))

      # We always have a graph.dot
      self.assertTrue(
          tf.gfile.Exists(
              os.path.join(
                  region_subdir,
                  self.reads_realigner.diagnostic_logger.graph_filename)))

      # NOTE(review): this path is built under region_str, whereas
      # region_subdir above uses assembled_region_str. Confirm which directory
      # the realigned reads are actually written to.
      reads_file = os.path.join(
          dx_dir, region_str,
          self.reads_realigner.diagnostic_logger.realigned_reads_filename)
      if emit_reads:
        self.assertTrue(tf.gfile.Exists(reads_file))
        reads_from_dx = io_utils.read_tfrecords(reads_file, reads_pb2.Read)
        # Order-insensitive comparison of the logged reads with the returned
        # realigned reads.
        self.assertCountEqual(reads_from_dx, realigned_reads)
      else:
        self.assertFalse(tf.gfile.Exists(reads_file))
 def test_call_end2end_with_empty_shards(self):
   """call_variants emits one record per example even with empty shards."""
   # Read at most 10 golden examples.
   golden_examples = io_utils.read_tfrecords(
       testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_examples)
   # Spread them over 15 shards, so several shards are necessarily empty.
   n_shards = 15
   source_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
   io_utils.write_tfrecords(examples, source_path)
   n_examples = len(examples)
   self.assertCallVariantsEmitsNRecordsForRandomGuess(source_path, n_examples)
Beispiel #11
0
 def test_call_end2end_with_empty_shards(self):
     """Runs call_variants on a sharded input in which some shards are empty."""
     # No more than 10 examples are read from the golden file.
     example_stream = io_utils.read_tfrecords(
         testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
     examples = list(example_stream)
     # With 15 shards and at most 10 examples, multiple shards stay empty.
     shard_count = 15
     source_path = test_utils.test_tmpfile('sharded@{}'.format(shard_count))
     io_utils.write_tfrecords(examples, source_path)
     expected_records = len(examples)
     self.assertCallVariantsEmitsNRecordsForRandomGuess(
         source_path, expected_records)
 def test_call_end2end_empty_first_shard(self):
   """call_variants handles a 2-shard input whose first shard has no records."""
   # Read at most 10 golden examples.
   golden_examples = io_utils.read_tfrecords(
       testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_examples)
   # Shard 0 is written empty; shard 1 receives every example.
   first_shard = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
   io_utils.write_tfrecords([], first_shard)
   second_shard = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
   io_utils.write_tfrecords(examples, second_shard)
   sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(
       sharded_spec, len(examples))
def make_golden_dataset(compressed_inputs=False):
  """Builds a DeepVariantDataSet over the golden training examples.

  Args:
    compressed_inputs: bool. If True, the golden examples are first re-written
      to a temporary '.tfrecord.gz' path and the dataset points at that copy
      (presumably gzip-compressed, based on the extension -- TODO confirm).
      Otherwise the dataset points at the golden file itself.

  Returns:
    A data_providers.DeepVariantDataSet named 'labeled_golden'.
  """
  source_path = testdata.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    rewritten_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    golden_examples = io_utils.read_tfrecords(source_path)
    io_utils.write_tfrecords(golden_examples, rewritten_path)
    source_path = rewritten_path
  return data_providers.DeepVariantDataSet(
      name='labeled_golden',
      source=source_path,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES)
Beispiel #14
0
 def test_call_end2end_empty_first_shard(self):
   """Runs call_variants on two shards where only the second has records."""
   # Up to 10 examples are taken from the golden calling examples.
   examples = list(
       io_utils.read_tfrecords(
           testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
   n_examples = len(examples)
   # First shard: no records at all.
   empty_shard = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
   io_utils.write_tfrecords([], empty_shard)
   # Second shard: all of the examples.
   full_shard = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
   io_utils.write_tfrecords(examples, full_shard)
   both_shards = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(both_shards, n_examples)
Beispiel #15
0
  def test_read_tfrecords_max_records(self, filename, max_records):
    """read_tfrecords yields at most max_records records (all if None)."""
    protos, path = self.write_test_protos(filename)

    # With no limit every written proto is expected; otherwise the limit caps
    # the count at the number of protos actually written.
    n_written = len(protos)
    if max_records is not None:
      expected_n = min(max_records, n_written)
    else:
      expected_n = n_written

    records = io.read_tfrecords(
        path, reference_pb2.ContigInfo, max_records=max_records)
    self.assertLen(list(records), expected_n)
def make_golden_dataset(compressed_inputs=False):
  """Returns the path of the golden postprocess input to use in a test.

  Args:
    compressed_inputs: bool. If True, the golden CallVariantsOutput records are
      re-written to a temporary '.tfrecord.gz' path and that path is returned
      (presumably gzip-compressed, based on the extension -- TODO confirm).
      Otherwise the golden path is returned unchanged.

  Returns:
    str. Path to the TFRecord file.
  """
  if not compressed_inputs:
    return testdata.GOLDEN_POSTPROCESS_INPUT

  source_path = test_utils.test_tmpfile(
      'golden.postprocess_single_site_input.tfrecord.gz')
  golden_records = io_utils.read_tfrecords(
      testdata.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)
  io_utils.write_tfrecords(golden_records, source_path)
  return source_path
Beispiel #17
0
def make_golden_dataset(compressed_inputs=False):
  """Picks the golden postprocess input path, optionally as a '.gz' copy.

  Args:
    compressed_inputs: bool. When True, the golden CallVariantsOutput protos
      are copied into a temporary file whose name ends in '.tfrecord.gz'
      (presumably gzip-compressed, based on the extension -- TODO confirm) and
      the copy's path is returned. When False, the golden path is returned
      as-is.

  Returns:
    str. Path of the TFRecord file to read.
  """
  source_path = testdata.GOLDEN_POSTPROCESS_INPUT
  if compressed_inputs:
    copy_path = test_utils.test_tmpfile(
        'golden.postprocess_single_site_input.tfrecord.gz')
    io_utils.write_tfrecords(
        io_utils.read_tfrecords(
            source_path, proto=deepvariant_pb2.CallVariantsOutput),
        copy_path)
    source_path = copy_path
  return source_path
 def test_call_variants_with_no_shape(self, model):
   """prepare_inputs raises ValueError for an example lacking image/shape."""
   # Take the first record of a valid file and strip its image/shape feature.
   golden_examples = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
   example = next(golden_examples)
   del example.features.feature['image/shape']
   source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([example], source_path)
   expected_error = ('Invalid image/shape: we expect to find an image/shape '
                     'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_error):
     call_variants.prepare_inputs(source_path, model, batch_size=1)
Beispiel #19
0
    def _call_end2end_helper(self, examples_path, model, shard_inputs):
        """Runs call_variants on the given examples and reads back its output.

        Args:
          examples_path: path of the tf.Example TFRecord file to run on.
          model: model object passed to call_variants; its `name` decides how
            many batches are run.
          shard_inputs: bool. If True, the examples are first re-written to a
            3-way sharded temp file and call_variants reads that instead.

        Returns:
          A tuple (call_variants_outputs, examples, batch_size, max_batches):
          the CallVariantsOutput protos read from the output file, the input
          examples as a list, and the batch settings that were used.
        """
        examples = list(io_utils.read_tfrecords(examples_path))

        source_path = examples_path
        if shard_inputs:
            # Create a sharded version of our golden examples.
            source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
            io_utils.write_tfrecords(examples, source_path)

        # If we point the test at a headless server, it will often be 2x2,
        # which has 8 replicas.  Otherwise a smaller batch size is fine.
        batch_size = 8 if FLAGS.use_tpu else 4

        # The random guess model runs over everything; every other model is
        # limited to a single inference batch.
        max_batches = None if model.name == 'random_guess' else 1

        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        call_variants.call_variants(
            examples_filename=source_path,
            checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
            model=model,
            output_file=outfile,
            batch_size=batch_size,
            max_batches=max_batches,
            master='',
            use_tpu=FLAGS.use_tpu,
        )

        output_records = io_utils.read_tfrecords(
            outfile, deepvariant_pb2.CallVariantsOutput)
        return list(output_records), examples, batch_size, max_batches
Beispiel #20
0
 def test_call_variants_with_no_shape(self, model):
   """Reading inference batches fails for an example without image/shape."""
   # Start from one good record of a valid file, then drop image/shape.
   golden_examples = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
   example = next(golden_examples)
   del example.features.feature['image/shape']
   source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([example], source_path)
   expected_error = ('Invalid image/shape: we expect to find an image/shape '
                     'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_error):
     ds = call_variants.prepare_inputs(source_path)
     # Drain the batches; the error must surface somewhere inside this block.
     batches = data_providers.get_infer_batches(ds, model=model, batch_size=1)
     list(batches)
def make_golden_dataset(compressed_inputs=False):
    """Creates the 'labeled_golden' DeepVariantDataSet used by the tests.

    Args:
      compressed_inputs: bool. When True, the golden training examples are
        copied to a temporary file named '*.tfrecord.gz' (presumably
        gzip-compressed, based on the extension -- TODO confirm) and the
        dataset reads from the copy. When False, the dataset reads from the
        golden file directly.

    Returns:
      A data_providers.DeepVariantDataSet over the chosen source path.
    """
    if compressed_inputs:
        source_path = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_examples = io_utils.read_tfrecords(
            testdata.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(golden_examples, source_path)
    else:
        source_path = testdata.GOLDEN_TRAINING_EXAMPLES

    dataset_kwargs = dict(
        name='labeled_golden',
        source=source_path,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES)
    return data_providers.DeepVariantDataSet(**dataset_kwargs)
Beispiel #22
0
    def test_make_read_writer_tfrecords(self):
        """SamWriter writes reads to a .tfrecord path that can be read back."""
        outfile = test_utils.test_tmpfile('test.tfrecord')
        writer = sam.SamWriter(outfile, header=self.header)

        # The writer is used as a context manager; both reads are written
        # inside the managed block.
        with writer:
            for read in (self.read1, self.read2):
                writer.write(read)

        # The output should contain exactly the two reads, in write order.
        written_reads = list(
            io_utils.read_tfrecords(outfile, proto=reads_pb2.Read))
        self.assertEqual([self.read1, self.read2], written_reads)
Beispiel #23
0
  def test_reading_sharded_input_with_empty_shards_does_not_crash(self):
    """postprocess_variants.main runs on a 2-shard input with an empty shard."""
    golden_outputs = io_utils.read_tfrecords(
        testdata.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    # Shard 0 holds no records; shard 1 holds all of the golden records.
    empty_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00000-of-00002')
    populated_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00001-of-00002')
    io_utils.write_tfrecords([], empty_shard)
    io_utils.write_tfrecords(golden_outputs, populated_shard)

    FLAGS.infile = test_utils.test_tmpfile('reading_empty_shard.tfrecord@2')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('calls_reading_empty_shard.vcf')

    # The test passes as long as main() returns without raising.
    postprocess_variants.main(['postprocess_variants.py'])
Beispiel #24
0
  def test_read_write_tfrecords(self, filename):
    """Protos written to `filename` round-trip through read_tfrecords."""
    protos, path = self.write_test_protos(filename)

    # read_tfrecords is expected to hand back a lazy generator of records.
    reader = io.read_tfrecords(path, reference_pb2.ContigInfo)
    self.assertEqual(type(reader), types.GeneratorType)

    # Check the round-trip contents.
    is_sharded = '@' in filename
    if not is_sharded:
      self.assertEqual(protos, list(reader))
    else:
      # Sharded outputs are striped across shards, so order isn't preserved
      # and only the multiset of records is compared.
      self.assertCountEqual(protos, reader)
Beispiel #25
0
  def test_make_read_writer_tfrecords(self):
    """Reads written by SamWriter to a .tfrecord file can be read back."""
    outfile = test_utils.test_tmpfile('test.tfrecord')
    writer = sam.SamWriter(outfile, header=self.header)

    # Exercise the writer as a context manager and write both reads to it.
    with writer:
      for read in (self.read1, self.read2):
        writer.write(read)

    # Our output should have exactly the two reads we wrote, in order.
    reads_back = list(io_utils.read_tfrecords(outfile, proto=reads_pb2.Read))
    self.assertEqual([self.read1, self.read2], reads_back)
 def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                   num_examples):
   """Asserts call_variants with 'random_guess' writes num_examples records.

   Args:
     filename: examples path (or sharded spec) passed to call_variants.
     num_examples: int. Expected number of CallVariantsOutput protos.
   """
   outfile = test_utils.test_tmpfile('call_variants.tfrecord')
   call_variants.call_variants(
       examples_filename=filename,
       checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
       model=modeling.get_model('random_guess'),
       output_file=outfile,
       batch_size=4,
       max_batches=None)
   output_records = io_utils.read_tfrecords(
       outfile, deepvariant_pb2.CallVariantsOutput)
   call_variants_outputs = list(output_records)
   # Check that we have the right number of output protos.
   self.assertEqual(len(call_variants_outputs), num_examples)
  def test_reading_sharded_dataset(self, compressed_inputs):
    """A dataset configured over a 3-way sharded copy matches the golden one."""
    golden_dataset = make_golden_dataset(compressed_inputs)

    # Re-write the golden examples into a sharded temp file.
    n_shards = 3
    sharded_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
    golden_examples = io_utils.read_tfrecords(golden_dataset.source)
    io_utils.write_tfrecords(golden_examples, sharded_path)

    # Point a dataset config at the sharded copy and load it.
    config_file = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=sharded_path,
        num_examples=golden_dataset.num_examples)
    sharded_slim_dataset = data_providers.get_dataset(
        config_file).get_slim_dataset()

    self.assertDataSetExamplesMatchExpected(sharded_slim_dataset,
                                            golden_dataset)
Beispiel #28
0
    def test_writing_canned_records(self):
        """Writes every canned GFF record and checks the resulting file."""
        # The canned records are stored in TFRecord format.
        tfrecord_file = test_utils.genomics_core_testdata(
            'test_features.gff.tfrecord')
        record_stream = io_utils.read_tfrecords(
            tfrecord_file, proto=gff_pb2.GffRecord)
        gff_records = list(record_stream)

        out_fname = test_utils.test_tmpfile('output.gff')
        writer_options = gff_pb2.GffWriterOptions()
        with gff_writer.GffWriter.to_file(out_fname, self.header,
                                          writer_options) as writer:
            for gff_record in gff_records:
                writer.write(gff_record)

        # The written lines must equal the expected GFF content.
        with open(out_fname) as f:
            written_lines = f.readlines()
            self.assertEqual(written_lines, self.expected_gff_content)
  def test_writing_canned_records(self):
    """Tests writing all the records that are 'canned' in our tfrecord file."""
    # The canned FASTQ records are stored in TFRecord format.
    tfrecord_file = test_utils.genomics_core_testdata(
        'test_reads.fastq.tfrecord')
    record_stream = io_utils.read_tfrecords(
        tfrecord_file, proto=fastq_pb2.FastqRecord)
    fastq_records = list(record_stream)

    out_fname = test_utils.test_tmpfile('output.fastq')
    writer_options = fastq_pb2.FastqWriterOptions()
    with fastq_writer.FastqWriter.to_file(out_fname, writer_options) as writer:
      for fastq_record in fastq_records:
        writer.write(fastq_record)

    # The written lines must equal the expected FASTQ content.
    with gfile.GFile(out_fname, 'r') as f:
      written_lines = f.readlines()
      self.assertEqual(written_lines, self.expected_fastq_content)
    def test_reading_sharded_dataset(self, compressed_inputs):
        """Loads a dataset from a sharded copy of the golden examples."""
        golden_dataset = make_golden_dataset(compressed_inputs)

        # Copy the golden examples into a 3-way sharded temp file.
        n_shards = 3
        sharded_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
        source_examples = io_utils.read_tfrecords(golden_dataset.source)
        io_utils.write_tfrecords(source_examples, sharded_path)

        # Describe the sharded copy in a dataset config and load it from there.
        config_kwargs = dict(
            name='sharded_test',
            tfrecord_path=sharded_path,
            num_examples=golden_dataset.num_examples)
        config_file = _test_dataset_config('test_sharded.pbtxt',
                                           **config_kwargs)
        loaded = data_providers.get_dataset(config_file)

        self.assertDataSetExamplesMatchExpected(loaded.get_slim_dataset(),
                                                golden_dataset)
 def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                   num_examples):
     """Runs call_variants with 'random_guess' and checks the output count.

     Args:
       filename: examples path (or sharded spec) passed to call_variants.
       num_examples: int. Number of CallVariantsOutput protos expected in the
         output file.
     """
     outfile = test_utils.test_tmpfile('call_variants.tfrecord')
     random_guess_model = modeling.get_model('random_guess')
     call_variants.call_variants(
         examples_filename=filename,
         checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
         model=random_guess_model,
         output_file=outfile,
         batch_size=4,
         max_batches=None)
     # Check that we have the right number of output protos.
     emitted = list(
         io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))
     self.assertEqual(len(emitted), num_examples)
    def assertTfDataSetExamplesMatchExpected(self,
                                             input_fn,
                                             expected_dataset,
                                             use_tpu=False,
                                             workaround_list_files=False):
        """Asserts that loci produced by input_fn match the raw input records.

        Args:
          input_fn: callable taking a params dict and returning a dataset
            whose one-shot iterator yields (features, label) batches.
          expected_dataset: object providing `input_file_spec`, `num_examples`
            and `tensor_shape`.
          use_tpu: bool. If True, each locus is converted with
            tf_utils.int_tensor_to_string before being compared.
          workaround_list_files: bool. If True, the seen loci are sorted
            before the comparison.
        """
        # Note that we use input_fn to get an iterator, while we use
        # expected_dataset to get a filename, even though they are the same
        # type (DeepVariantInput), and may even be the same object.
        seen = []
        with tf.Session() as sess:
            dataset = input_fn({'batch_size': 1})
            next_batch = dataset.make_one_shot_iterator().get_next()

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # Pull batches until the iterator is exhausted.
            while True:
                try:
                    features, _ = sess.run(next_batch)
                except tf.errors.OutOfRangeError:
                    break
                first_locus = features['locus'][0]
                if use_tpu:
                    first_locus = tf_utils.int_tensor_to_string(first_locus)
                # NB, this looks like: array(['chr20:10001019-10001019'], dtype=object)
                seen.append(first_locus)

        if workaround_list_files:
            # This really only works for loci, because those are string valued
            # and are expected to show up in sorted order.  For arbitrary data
            # that's not true.  In prod we have the version of tf that lets us
            # turn off shuffling so this path is skipped, but kokoro hits this.
            seen.sort()

        expected_loci = []
        for example in io_utils.read_tfrecords(
                expected_dataset.input_file_spec):
            locus_feature = example.features.feature['locus']
            expected_loci.append(locus_feature.bytes_list.value[0])

        self.assertEqual(len(expected_loci), expected_dataset.num_examples)
        # Dump both sides before the assertion fails, to ease debugging.
        if seen != expected_loci:
            print('\n\nlen expected seen', len(expected_loci), len(seen))
            print('\n\nexpected=', expected_loci)
            print('\n\nseen=', seen)
        self.assertEqual(expected_loci, seen)
        # Note that this expected shape comes from the golden dataset. If the
        # data is remade in the future, the values might need to be modified
        # accordingly.
        self.assertEqual(dv_constants.PILEUP_DEFAULT_DIMS,
                         expected_dataset.tensor_shape)
Beispiel #33
0
  def assertCallVariantsEmitsNRecordsForInceptionV3(self, filename,
                                                    num_examples):
    """Asserts call_variants with 'inception_v3' writes num_examples records.

    Args:
      filename: examples path (or sharded spec) passed to call_variants.
      num_examples: int. Expected number of CallVariantsOutput protos.
    """
    outfile = test_utils.test_tmpfile('inception_v3.call_variants.tfrecord')
    inception_model = modeling.get_model('inception_v3')

    call_variants.call_variants(
        examples_filename=filename,
        checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
        model=inception_model,
        output_file=outfile,
        batch_size=4,
        max_batches=None)
    # Check that we have the right number of output protos.
    output_records = io_utils.read_tfrecords(
        outfile, deepvariant_pb2.CallVariantsOutput)
    self.assertEqual(len(list(output_records)), num_examples)
  def test_writing_canned_records(self):
    """Writes every canned BED record and checks the resulting file."""
    # The canned BED records are stored in TFRecord format.
    tfrecord_file = test_utils.genomics_core_testdata(
        'test_regions.bed.tfrecord')
    record_stream = io_utils.read_tfrecords(
        tfrecord_file, proto=bed_pb2.BedRecord)

    header = bed_pb2.BedHeader(num_fields=12)
    writer_options = bed_pb2.BedWriterOptions()
    bed_records = list(record_stream)
    out_fname = test_utils.test_tmpfile('output.bed')
    with bed_writer.BedWriter.to_file(out_fname, header,
                                      writer_options) as writer:
      for bed_record in bed_records:
        writer.write(bed_record)

    # The written lines must equal the expected BED content.
    with gfile.GFile(out_fname, 'r') as f:
      written_lines = f.readlines()
      self.assertEqual(written_lines, self.expected_bed_content)
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """call_variants raises ValueError when image/format is `bad_format`."""
    # Take the first record of a valid file.
    golden_examples = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
    example = next(golden_examples)
    # Overwrite the image/format field to be an invalid value
    # (anything but 'raw').
    format_feature = example.features.feature['image/format']
    format_feature.bytes_list.value[0] = bad_format
    source_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([example], source_path)
    outfile = test_utils.test_tmpfile('call_variants_invalid_format.tfrecord')

    call_kwargs = dict(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=1,
        max_batches=1)
    with self.assertRaises(ValueError):
      call_variants.call_variants(**call_kwargs)
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
    """Builds an input fn over the golden training examples.

    Args:
      compressed_inputs: bool. If True, the golden examples are first copied
        to a temporary file whose name ends in `.tfrecord.gz` and that copy is
        used as the input (presumably io_utils compresses based on the
        suffix; confirm against io_utils.write_tfrecords).
      mode: forwarded to data_providers.get_input_fn_from_filespec.
      use_tpu: forwarded to data_providers.get_input_fn_from_filespec.

    Returns:
      Whatever data_providers.get_input_fn_from_filespec returns for the
      chosen file, named 'labeled_golden'.
    """
    golden_path = testdata.GOLDEN_TRAINING_EXAMPLES
    if not compressed_inputs:
        source_path = golden_path
    else:
        source_path = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_records = io_utils.read_tfrecords(golden_path)
        io_utils.write_tfrecords(golden_records, source_path)

    return data_providers.get_input_fn_from_filespec(
        input_file_spec=source_path,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
        name='labeled_golden',
        mode=mode,
        tensor_shape=None,
        use_tpu=use_tpu)
Beispiel #37
0
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Checks that a bad image/format value makes call_variants raise."""
    # Take a single valid record from the golden calling examples.
    tampered = next(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
    # Replace image/format with an invalid value (anything but 'raw').
    format_values = tampered.features.feature['image/format'].bytes_list.value
    format_values[0] = bad_format

    tampered_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([tampered], tampered_path)
    result_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=tampered_path,
          checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
          model=model,
          output_file=result_path,
          batch_size=1,
          max_batches=1,
          use_tpu=FLAGS.use_tpu)
  def verify_examples(self, examples_filename, region, options, verify_labels):
    """Runs structural checks on the tf.Examples stored in a file.

    Args:
      examples_filename: path handed to io_utils.read_tfrecords.
      region: forwarded to self.verify_variants.
      options: forwarded to self.verify_variants.
      verify_labels: bool. If True, each example must also carry a 'label'
        feature.

    Returns:
      The list of examples read from `examples_filename`.
    """
    required_features = [
        'variant/encoded', 'locus', 'image/format', 'image/encoded',
        'alt_allele_indices/encoded'
    ]
    if verify_labels:
      required_features.append('label')

    examples = list(io_utils.read_tfrecords(examples_filename))
    for ex in examples:
      feature_map = ex.features.feature
      for feature_name in required_features:
        self.assertIn(feature_name, feature_map)
      # Every example must reference at least one alt allele.
      # pylint: disable=g-explicit-length-test
      self.assertGreater(len(tf_utils.example_alt_alleles_indices(ex)), 0)

    # The variants embedded in the examples must pass the variant checks too.
    embedded_variants = [tf_utils.example_variant(ex) for ex in examples]
    self.verify_variants(embedded_variants, region, options, is_gvcf=False)

    return examples
    def test_reading_sharded_dataset(self, compressed_inputs, use_tpu):
        """Copies the golden dataset to a 3-shard path and reads it back."""
        golden_dataset = make_golden_dataset(compressed_inputs,
                                             use_tpu=use_tpu)

        # Write the golden records to a sharded file spec.
        shard_count = 3
        sharded_path = test_utils.test_tmpfile(
            'sharded@{}'.format(shard_count))
        golden_records = io_utils.read_tfrecords(
            golden_dataset.input_file_spec)
        io_utils.write_tfrecords(golden_records, sharded_path)

        config_file = _test_dataset_config(
            'test_sharded.pbtxt',
            name='sharded_test',
            tfrecord_path=sharded_path,
            num_examples=golden_dataset.num_examples)
        sharded_input_fn = data_providers.get_input_fn_from_dataset(
            config_file, mode=tf.estimator.ModeKeys.EVAL)

        self.assertTfDataSetExamplesMatchExpected(
            sharded_input_fn,
            golden_dataset,
            # workaround_list_files is needed because wildcards, and so
            # sharded files, are nondeterministically ordered (for now).
            workaround_list_files=True,
        )
Beispiel #40
0
def get_one_example_from_examples_path(source, proto=None):
    """Get the first record from `source`.

    Args:
      source: str. A pattern or a comma-separated list of patterns that
        represent file names.
      proto: A proto class. proto.FromString() will be called on each
        serialized record in path to parse it.

    Returns:
      The first record of the first non-empty matching file, or None if every
      matching file is empty.

    Raises:
      ValueError: if no file matches `source`.
    """
    files = io_utils.GlobListShardedFilePatterns(source)
    if not files:
        raise ValueError(
            'Cannot find matching files with the pattern "{}"'.format(source))
    for f in files:
        try:
            # Use the next() builtin rather than the iterator's .next()
            # method: the method only exists on Python 2 iterators, while the
            # builtin works on both Python 2 and Python 3.
            return next(io_utils.read_tfrecords(f, proto=proto))
        except StopIteration:
            # StopIteration from next() means this file is empty.
            # Move on to the next one to try to get one example.
            pass
    return None
  def test_writing_canned_variants(self):
    """Writes the first five canned variants to a VCF and checks its text."""
    # The canned variants are stored as Variant protos in TFRecord format.
    tfrecord_file = test_utils.genomics_core_testdata(
        'test_samples.vcf.golden.tfrecord')

    writer_options = variants_pb2.VcfWriterOptions()

    # (name, n_bases) for each contig declared in the header.
    contig_specs = [
        ('chr1', 248956422),
        ('chr2', 242193529),
        ('chr3', 198295559),
        ('chrX', 156040895),
    ]
    # Filters that are declared with an id only.
    tranche_filter_ids = [
        'VQSRTrancheINDEL95.00to96.00',
        'VQSRTrancheINDEL96.00to97.00',
        'VQSRTrancheINDEL97.00to99.00',
        'VQSRTrancheINDEL99.00to99.50',
        'VQSRTrancheINDEL99.50to99.90',
        'VQSRTrancheINDEL99.90to99.95',
        'VQSRTrancheINDEL99.95to100.00+',
        'VQSRTrancheINDEL99.95to100.00',
        'VQSRTrancheSNP99.50to99.60',
        'VQSRTrancheSNP99.60to99.80',
        'VQSRTrancheSNP99.80to99.90',
        'VQSRTrancheSNP99.90to99.95',
        'VQSRTrancheSNP99.95to100.00+',
        'VQSRTrancheSNP99.95to100.00',
    ]
    # (id, number, type, description) for each FORMAT header entry.
    format_specs = [
        ('GT', '1', 'String', 'Genotype'),
        ('GQ', '1', 'Integer', 'Genotype Quality'),
        ('DP', '1', 'Integer', 'Read depth of all passing filters reads.'),
        ('MIN_DP', '1', 'Integer',
         'Minimum DP observed within the GVCF block.'),
        ('AD', 'R', 'Integer',
         'Read depth of all passing filters reads for each allele.'),
        ('VAF', 'A', 'Float', 'Variant allele fractions.'),
        ('GL', 'G', 'Float', 'Genotype likelihoods, log10 encoded'),
        ('PL', 'G', 'Integer', 'Genotype likelihoods, Phred encoded'),
    ]

    filter_infos = [
        variants_pb2.VcfFilterInfo(
            id='PASS', description='All filters passed'),
        variants_pb2.VcfFilterInfo(id='LowQual', description=''),
    ]
    filter_infos.extend(
        variants_pb2.VcfFilterInfo(id=filter_id)
        for filter_id in tranche_filter_ids)

    header = variants_pb2.VcfHeader(
        contigs=[
            reference_pb2.ContigInfo(name=contig_name, n_bases=contig_len)
            for contig_name, contig_len in contig_specs
        ],
        sample_names=['NA12878_18_99'],
        filters=filter_infos,
        infos=[
            variants_pb2.VcfInfo(
                id='END',
                number='1',
                type='Integer',
                description='Stop position of the interval')
        ],
        formats=[
            variants_pb2.VcfFormatInfo(
                id=format_id,
                number=format_number,
                type=format_type,
                description=format_description)
            for format_id, format_number, format_type, format_description
            in format_specs
        ],
    )

    variant_records = list(
        io_utils.read_tfrecords(tfrecord_file, proto=variants_pb2.Variant))
    out_fname = test_utils.test_tmpfile('output.vcf')
    # Only the first five variants are written.
    with vcf_writer.VcfWriter.to_file(out_fname, header,
                                      writer_options) as writer:
      for variant in variant_records[:5]:
        writer.write(variant)

    # Check: are the variants written as expected?
    # pylint: disable=line-too-long
    expected_vcf_content = [
        '##fileformat=VCFv4.2\n',
        '##FILTER=<ID=PASS,Description="All filters passed">\n',
        '##FILTER=<ID=LowQual,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
        '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
        'the interval">\n',
        '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
        '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
        '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
        'passing filters reads.">\n',
        '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP '
        'observed within the GVCF block.">\n',
        '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
        'passing filters reads for each allele.">\n',
        '##FORMAT=<ID=VAF,Number=A,Type=Float,Description=\"Variant allele '
        'fractions.">\n',
        '##FORMAT=<ID=GL,Number=G,Type=Float,Description="Genotype '
        'likelihoods, log10 encoded">\n',
        '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
        'likelihoods, Phred encoded">\n',
        '##contig=<ID=chr1,length=248956422>\n',
        '##contig=<ID=chr2,length=242193529>\n',
        '##contig=<ID=chr3,length=198295559>\n',
        '##contig=<ID=chrX,length=156040895>\n',
        '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
        'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
        'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
        'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
        'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
        'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
    ]
    # pylint: enable=line-too-long

    with tf.gfile.GFile(out_fname, 'r') as written:
      actual_vcf_content = written.readlines()
    self.assertEqual(actual_vcf_content, expected_vcf_content)
    def testGoldenCallingExamples(self, use_tpu):
        """Checks the batch feed yields the same records as the golden file.

        The golden calling examples are read directly with io_utils and also
        through self.get_batch_feed; each batch of size one is compared with
        the golden example at the same position.
        """
        # Canonical data, parsed straight from the TFRecord file.
        golden_examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES,
                                    proto=example_pb2.Example))

        # The function under test (parse_tfexample is exercised indirectly).
        batch_feed = self.get_batch_feed(batch_size=1, use_tpu=use_tpu)

        def check_and_decode(value):
            """Asserts the type of an encoded feature and returns its string."""
            if use_tpu:
                # On TPU the string arrives as a fixed-length int32 buffer.
                self.assertEqual(value.dtype, np.dtype('int32'))
                self.assertEqual(value.shape,
                                 (tf_utils.STRING_TO_INT_BUFFER_LENGTH, ))
                return tf_utils.int_tensor_to_string(value)
            self.assertIsInstance(value, six.string_types)
            return value

        with self.test_session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            n_batches = 0
            while True:
                try:
                    features = sess.run(batch_feed)
                except tf.errors.OutOfRangeError:
                    # The feed is exhausted.
                    break

                # Expected values come from the golden example at this index.
                golden_features = golden_examples[n_batches].features.feature
                expected_alt_allele_indices_encoded = golden_features[
                    'alt_allele_indices/encoded'].bytes_list.value[0]
                expected_variant_encoded = golden_features[
                    'variant/encoded'].bytes_list.value[0]

                image = features['image'][0]  # np.ndarray
                self.assertEqual(list(image.shape),
                                 dv_constants.PILEUP_DEFAULT_DIMS)
                self.assertIsNotNone(image)
                expected_image_dtype = 'int32' if use_tpu else 'uint8'
                self.assertEqual(image.dtype, np.dtype(expected_image_dtype))

                actual_alt_allele_indices_encoded = check_and_decode(
                    features['alt_allele_indices'][0])
                self.assertEqual(expected_alt_allele_indices_encoded,
                                 actual_alt_allele_indices_encoded)

                actual_variant_encoded = check_and_decode(
                    features['variant'][0])
                self.assertEqual(expected_variant_encoded,
                                 actual_variant_encoded)

                n_batches += 1

            self.assertEqual(n_batches, testdata.N_GOLDEN_CALLING_EXAMPLES)
def main(argv=()):
  """Runs postprocess_variants: writes a VCF and, optionally, a gVCF.

  CallVariantsOutput records are read from FLAGS.infile, converted to variants
  and written to FLAGS.outfile. When FLAGS.nonvariant_site_tfrecord_path is
  set, the VCF that was just written is read back, merged with the non-variant
  records and written to FLAGS.gvcf_outfile.

  Args:
    argv: command-line arguments. More than one element is treated as an
      unexpected positional argument and reported via errors.log_and_raise
      with errors.CommandLineError.
  """
  # NOTE(review): how clean_commandline_error_exit surfaces the errors raised
  # below is not visible in this block.
  with errors.clean_commandline_error_exit():
    if len(argv) > 1:
      errors.log_and_raise(
          'Command line parsing failure: postprocess_variants does not accept '
          'positional arguments but some are present on the command line: '
          '"{}".'.format(str(argv)), errors.CommandLineError)
    del argv  # Unused.

    # The two gVCF flags must be either both set or both unset.
    if (not FLAGS.nonvariant_site_tfrecord_path) != (not FLAGS.gvcf_outfile):
      errors.log_and_raise(
          'gVCF creation requires both nonvariant_site_tfrecord_path and '
          'gvcf_outfile flags to be set.', errors.CommandLineError)

    proto_utils.uses_fast_cpp_protos_or_die()

    logging_level.set_from_flag()

    fasta_reader = fasta.RefFastaReader(FLAGS.ref, cache_size=_FASTA_CACHE_SIZE)
    contigs = fasta_reader.header.contigs
    paths = io_utils.maybe_generate_sharded_filenames(FLAGS.infile)
    # Read one CallVariantsOutput record and extract the sample name from it.
    # Note that this assumes that all CallVariantsOutput protos in the infile
    # contain a single VariantCall within their constituent Variant proto, and
    # that the call_set_name is identical in each of the records.
    # NOTE(review): paths[0] and next() assume at least one input path and at
    # least one record in it; an empty input is not handled here.
    record = next(
        io_utils.read_tfrecords(
            paths[0], proto=deepvariant_pb2.CallVariantsOutput, max_records=1))
    sample_name = _extract_single_sample_name(record)
    header = dv_vcf_constants.deepvariant_header(
        contigs=contigs, sample_names=[sample_name])
    # The temporary file receives the output of process_single_sites_tfrecords
    # and is then used as the sorted input for the conversion to variants.
    # The VCF is written inside this `with` block, presumably because the
    # variant generators read the temp file lazily (TODO confirm).
    with tempfile.NamedTemporaryFile() as temp:
      postprocess_variants_lib.process_single_sites_tfrecords(
          contigs, paths, temp.name)
      independent_variants = _transform_call_variants_output_to_variants(
          input_sorted_tfrecord_path=temp.name,
          qual_filter=FLAGS.qual_filter,
          multi_allelic_qual_filter=FLAGS.multi_allelic_qual_filter,
          sample_name=sample_name)
      variant_generator = haplotypes.maybe_resolve_conflicting_variants(
          independent_variants)
      write_variants_to_vcf(
          variant_generator=variant_generator,
          output_vcf_path=FLAGS.outfile,
          header=header)

    # Also write out the gVCF file if it was provided.
    if FLAGS.nonvariant_site_tfrecord_path:
      nonvariant_generator = io_utils.read_shard_sorted_tfrecords(
          FLAGS.nonvariant_site_tfrecord_path,
          key=_get_contig_based_variant_sort_keyfn(contigs),
          proto=variants_pb2.Variant)
      # The variants are re-read from the VCF written above rather than being
      # kept in memory.
      with vcf.VcfReader(FLAGS.outfile, use_index=False) as variant_reader:
        lessthanfn = _get_contig_based_lessthan(contigs)
        gvcf_variants = (
            _transform_to_gvcf_record(variant)
            for variant in variant_reader.iterate())
        merged_variants = merge_variants_and_nonvariants(
            gvcf_variants, nonvariant_generator, lessthanfn, fasta_reader)
        write_variants_to_vcf(
            variant_generator=merged_variants,
            output_vcf_path=FLAGS.gvcf_outfile,
            header=header)
Beispiel #44
0
 def setUpClass(cls):
     """Loads the golden calling examples and a random-guess model once."""
     golden_examples = list(
         io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
     cls.examples = golden_examples
     # One variant per example, in the same order as the examples.
     cls.variants = list(map(tf_utils.example_variant, golden_examples))
     cls.model = modeling.get_model('random_guess')
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """Runs call_variants on the golden examples and checks its outputs.

    Args:
      model: the model passed to call_variants. Every example is processed
        when model.name is 'random_guess'; other models run a single batch.
      shard_inputs: bool. If True, the golden examples are first re-written to
        a 3-way sharded path and that path is used as the input.
      include_debug_info: bool. Value assigned to FLAGS.include_debug_info;
        also selects which debug_info checks are applied to the outputs.
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = testdata.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs), batch_size * max_batches
        if max_batches else len(examples))

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants (max_batches
    # is None), since we processed all of the examples with that model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variant_utils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variant_utils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp, variant_utils.is_snp(
            cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      """True if the example and the CVO share variant and alt allele indices."""
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(
                  example) == call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though currently
      # as implemented we find our example using this information so it cannot
      # fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1. The
      # generator must be reduced with all(): a bare generator object is
      # always truthy, so assertTrue on it could never fail.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))
 def setUpClass(cls):
   """Caches the golden examples, their variants and a random-guess model."""
   examples = list(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))
   variants = []
   for example in examples:
     variants.append(tf_utils.example_variant(example))
   cls.examples = examples
   cls.variants = variants
   cls.model = modeling.get_model('random_guess')
  def test_make_examples_end2end(self, mode, num_shards,
                                 labeler_algorithm=None):
    """Runs make_examples over a chr20 region and checks every output.

    Args:
      mode: 'calling' or 'training'; anything else fails the first assertion.
      num_shards: number of shards for the candidates, examples and gvcf
        outputs. One runner task is executed per shard, and at least one task
        always runs.
      labeler_algorithm: optional value for FLAGS.labeler_algorithm; the flag
        is left untouched when this is None.
    """
    # Do not truncate diffs in assertion failure messages.
    self.maxDiff = None
    self.assertIn(mode, {'calling', 'training'})
    region = ranges.parse_literal('chr20:10,000,000-10,010,000')
    # Configure the run through the global FLAGS; default_options below is
    # called with add_flags=True.
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.CHR20_BAM
    FLAGS.candidates = test_utils.test_tmpfile(
        _sharded('vsc.tfrecord', num_shards))
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('examples.tfrecord', num_shards))
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.partition_size = 1000
    FLAGS.mode = mode
    FLAGS.gvcf_gq_binsize = 5
    if labeler_algorithm is not None:
      FLAGS.labeler_algorithm = labeler_algorithm

    # Calling mode additionally emits gVCF records; training mode needs the
    # truth variants and confident regions instead.
    if mode == 'calling':
      FLAGS.gvcf = test_utils.test_tmpfile(
          _sharded('gvcf.tfrecord', num_shards))
    else:
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    # Run one task per shard. `options` from the last iteration is reused by
    # the verification helpers below.
    for task_id in range(max(num_shards, 1)):
      FLAGS.task = task_id
      options = make_examples.default_options(add_flags=True)
      make_examples.make_examples_runner(options)

    # Test that our candidates are reasonable, calling specific helper functions
    # to check lots of properties of the output.
    candidates = sorted(
        io_utils.read_tfrecords(
            FLAGS.candidates, proto=deepvariant_pb2.DeepVariantCall),
        key=lambda c: variant_utils.variant_range_tuple(c.variant))
    self.verify_deepvariant_calls(candidates, options)
    self.verify_variants(
        [call.variant for call in candidates], region, options, is_gvcf=False)

    # Verify that the variants in the examples are all good.
    examples = self.verify_examples(
        FLAGS.examples, region, options, verify_labels=mode == 'training')
    example_variants = [tf_utils.example_variant(ex) for ex in examples]
    self.verify_variants(example_variants, region, options, is_gvcf=False)

    # Verify the integrity of the examples and then check that they match our
    # golden labeled examples. Note we expect the order for both training and
    # calling modes to produce deterministic order because we fix the random
    # seed.
    if mode == 'calling':
      golden_file = _sharded(testdata.GOLDEN_CALLING_EXAMPLES, num_shards)
    else:
      golden_file = _sharded(testdata.GOLDEN_TRAINING_EXAMPLES, num_shards)
    self.assertDeepVariantExamplesEqual(
        examples, list(io_utils.read_tfrecords(golden_file)))

    if mode == 'calling':
      # NOTE(review): this reader is never closed explicitly.
      nist_reader = vcf.VcfReader(testdata.TRUTH_VARIANTS_VCF)
      nist_variants = list(nist_reader.query(region))
      self.verify_nist_concordance(example_variants, nist_variants)

      # Check the quality of our generated gvcf file.
      gvcfs = variant_utils.sorted_variants(
          io_utils.read_tfrecords(FLAGS.gvcf, proto=variants_pb2.Variant))
      self.verify_variants(gvcfs, region, options, is_gvcf=True)
      self.verify_contiguity(gvcfs, region)
      gvcf_golden_file = _sharded(testdata.GOLDEN_POSTPROCESS_GVCF_INPUT,
                                  num_shards)
      expected_gvcfs = list(
          io_utils.read_tfrecords(gvcf_golden_file, proto=variants_pb2.Variant))
      self.assertItemsEqual(gvcfs, expected_gvcfs)
    def test_writing_canned_variants(self):
        """Writes the first five canned variants to a VCF and checks its text."""
        # The canned variants are Variant protos stored in TFRecord format.
        tfrecord_file = test_utils.genomics_core_testdata(
            'test_samples.vcf.golden.tfrecord')

        writer_options = variants_pb2.VcfWriterOptions()

        contigs = [
            reference_pb2.ContigInfo(name=contig_name, n_bases=contig_len)
            for contig_name, contig_len in (
                ('chr1', 248956422),
                ('chr2', 242193529),
                ('chr3', 198295559),
                ('chrX', 156040895),
            )
        ]

        # The first two filters carry an explicit description; the tranche
        # filters are declared with an id only.
        filters = [
            variants_pb2.VcfFilterInfo(id='PASS',
                                       description='All filters passed'),
            variants_pb2.VcfFilterInfo(id='LowQual', description=''),
        ]
        for tranche_id in (
                'VQSRTrancheINDEL95.00to96.00',
                'VQSRTrancheINDEL96.00to97.00',
                'VQSRTrancheINDEL97.00to99.00',
                'VQSRTrancheINDEL99.00to99.50',
                'VQSRTrancheINDEL99.50to99.90',
                'VQSRTrancheINDEL99.90to99.95',
                'VQSRTrancheINDEL99.95to100.00+',
                'VQSRTrancheINDEL99.95to100.00',
                'VQSRTrancheSNP99.50to99.60',
                'VQSRTrancheSNP99.60to99.80',
                'VQSRTrancheSNP99.80to99.90',
                'VQSRTrancheSNP99.90to99.95',
                'VQSRTrancheSNP99.95to100.00+',
                'VQSRTrancheSNP99.95to100.00',
        ):
            filters.append(variants_pb2.VcfFilterInfo(id=tranche_id))

        infos = [
            variants_pb2.VcfInfo(
                id='END',
                number='1',
                type='Integer',
                description='Stop position of the interval')
        ]

        # (id, number, type, description) for each FORMAT header entry.
        formats = [
            variants_pb2.VcfFormatInfo(id=format_id,
                                       number=format_number,
                                       type=format_type,
                                       description=format_description)
            for format_id, format_number, format_type, format_description in (
                ('GT', '1', 'String', 'Genotype'),
                ('GQ', '1', 'Integer', 'Genotype Quality'),
                ('DP', '1', 'Integer',
                 'Read depth of all passing filters reads.'),
                ('MIN_DP', '1', 'Integer',
                 'Minimum DP observed within the GVCF block.'),
                ('AD', 'R', 'Integer',
                 'Read depth of all passing filters reads for each allele.'),
                ('VAF', 'A', 'Float', 'Variant allele fractions.'),
                ('PL', 'G', 'Integer', 'Genotype likelihoods, Phred encoded'),
            )
        ]

        header = variants_pb2.VcfHeader(
            contigs=contigs,
            sample_names=['NA12878_18_99'],
            filters=filters,
            infos=infos,
            formats=formats,
        )

        variant_records = list(
            io_utils.read_tfrecords(tfrecord_file, proto=variants_pb2.Variant))
        out_fname = test_utils.test_tmpfile('output.vcf')
        # Only the first five variants are written.
        with vcf_writer.VcfWriter.to_file(out_fname, header,
                                          writer_options) as writer:
            for variant in variant_records[:5]:
                writer.write(variant)

        # Check: are the variants written as expected?
        # pylint: disable=line-too-long
        expected_vcf_content = [
            '##fileformat=VCFv4.2\n',
            '##FILTER=<ID=PASS,Description="All filters passed">\n',
            '##FILTER=<ID=LowQual,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
            '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
            'the interval">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
            '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
            '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
            'passing filters reads.">\n',
            '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP '
            'observed within the GVCF block.">\n',
            '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
            'passing filters reads for each allele.">\n',
            '##FORMAT=<ID=VAF,Number=A,Type=Float,Description=\"Variant allele '
            'fractions.">\n',
            '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
            'likelihoods, Phred encoded">\n',
            '##contig=<ID=chr1,length=248956422>\n',
            '##contig=<ID=chr2,length=242193529>\n',
            '##contig=<ID=chr3,length=198295559>\n',
            '##contig=<ID=chrX,length=156040895>\n',
            '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
            'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
            'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
            'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
            'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
            'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
        ]
        # pylint: enable=line-too-long

        with gfile.GFile(out_fname, 'r') as written:
            actual_vcf_content = written.readlines()
        self.assertEqual(actual_vcf_content, expected_vcf_content)