Exemplo n.º 1
0
def _transform_call_variants_output_to_variants(input_sorted_tfrecord_path,
                                                qual_filter,
                                                multi_allelic_qual_filter,
                                                sample_name):
    """Lazily converts sorted CallVariantsOutput protos into Variant protos.

    CallVariantsOutput records are read from the input TFRecord and grouped
    by the range of their variant; consecutive records sharing a range form
    one group. Each group is merged into a single Variant proto, with the
    following filters applied: 1) variants are omitted if their quality is
    lower than the `qual_filter` threshold. 2) multi-allelic variants omit
    individual alleles whose qualities are lower than the
    `multi_allelic_qual_filter` threshold.

    Args:
      input_sorted_tfrecord_path: str. TFRecord format file containing sorted
        CallVariantsOutput protos.
      qual_filter: double. The qual value below which to filter variants.
      multi_allelic_qual_filter: double. The qual value below which to filter
        multi-allelic variants.
      sample_name: str. Sample name to write to VCF file.

    Yields:
      Variant protos in sorted order representing the CallVariantsOutput
      calls.
    """

    def _variant_range_key(call_variants_output):
        # Grouping key: the range of the variant carried by the record.
        return variantutils.variant_range(call_variants_output.variant)

    sorted_outputs = io_utils.read_tfrecords(
        input_sorted_tfrecord_path, proto=deepvariant_pb2.CallVariantsOutput)
    # groupby only merges *adjacent* records with equal keys, which is why
    # the input file has to be sorted.
    for _, same_range_outputs in itertools.groupby(sorted_outputs,
                                                   _variant_range_key):
        canonical_variant, predictions = merge_predictions(
            list(same_range_outputs), multi_allelic_qual_filter)
        yield add_call_to_variant(canonical_variant,
                                  predictions,
                                  qual_filter=qual_filter,
                                  sample_name=sample_name)
    def verify_examples(self, examples_filename, region, options,
                        verify_labels):
        """Runs simple structural checks on the tf.Examples in a file.

        Args:
          examples_filename: path passed to io_utils.read_tfrecords to load
            the examples.
          region: forwarded unchanged to self.verify_variants.
          options: forwarded unchanged to self.verify_variants.
          verify_labels: bool. When true, the label features must also be
            present and each example's variant must share its position with
            its truth variant.

        Returns:
          The list of examples read from examples_filename.
        """
        required_features = [
            'variant/encoded', 'locus', 'image/format', 'image/encoded',
            'alt_allele_indices/encoded'
        ]
        if verify_labels:
            required_features.extend(['label', 'truth_variant/encoded'])

        examples = list(io_utils.read_tfrecords(examples_filename))
        for example in examples:
            for feature_name in required_features:
                self.assertIn(feature_name, example.features.feature)
            # Every example must carry at least one alt allele index.
            # pylint: disable=g-explicit-length-test
            n_alt_indices = len(tf_utils.example_alt_alleles_indices(example))
            self.assertGreater(n_alt_indices, 0)

            if not verify_labels:
                continue
            # Our variant and our truth_variant must have the same position.
            variant_position = variantutils.variant_position(
                tf_utils.example_variant(example))
            truth_position = variantutils.variant_position(
                tf_utils.example_truth_variant(example))
            self.assertEqual(variant_position, truth_position)

        # Check that the variants in the examples are good.
        variants = []
        for example in examples:
            variants.append(tf_utils.example_variant(example))
        self.verify_variants(variants, region, options, is_gvcf=False)

        return examples
def main(argv=()):
    """Converts the CallVariantsOutput records in FLAGS.infile to a VCF.

    Args:
      argv: command-line arguments; anything beyond the program name is
        reported as a CommandLineError.
    """
    with errors.clean_commandline_error_exit():
        has_positional_args = len(argv) > 1
        if has_positional_args:
            message = (
                'Command line parsing failure: postprocess_variants does not accept '
                'positional arguments but some are present on the command line: '
                '"{}".').format(str(argv))
            errors.log_and_raise(message, errors.CommandLineError)
        del argv  # Unused.
        proto_utils.uses_fast_cpp_protos_or_die()

        logging_level.set_from_flag()

        with genomics_io.make_ref_reader(FLAGS.ref) as ref_reader:
            contigs = ref_reader.contigs
        paths = io_utils.maybe_generate_sharded_filenames(FLAGS.infile)
        with tempfile.NamedTemporaryFile() as sorted_tmp:
            # The temp file is handed to process_single_sites_tfrecords and
            # then read back below as the sorted TFRecord input.
            postprocess_variants_lib.process_single_sites_tfrecords(
                contigs, paths, sorted_tmp.name)
            # Read one CallVariantsOutput record and extract the sample name
            # from it. This assumes that all CallVariantsOutput protos in the
            # infile contain a single VariantCall within their constituent
            # Variant proto, and that the call_set_name is identical in each
            # of the records.
            first_record_reader = io_utils.read_tfrecords(
                paths[0],
                proto=deepvariant_pb2.CallVariantsOutput,
                max_records=1)
            first_record = next(first_record_reader)
            sample_name = _extract_single_sample_name(first_record)
            write_call_variants_output_to_vcf(
                contigs=contigs,
                input_sorted_tfrecord_path=sorted_tmp.name,
                output_vcf_path=FLAGS.outfile,
                qual_filter=FLAGS.qual_filter,
                multi_allelic_qual_filter=FLAGS.multi_allelic_qual_filter,
                sample_name=sample_name)
Exemplo n.º 4
0
def _get_one_example_from_examples_path(source):
    """Reads one record from source.

    Args:
      source: str. A file pattern, or a comma-separated list of file
        patterns. Each pattern is normalized with
        io_utils.NormalizeToShardedFilePattern and expanded with
        tf.gfile.Glob.

    Returns:
      The first record yielded by io_utils.read_tfrecords over the matched
      files, or None if every matched file is empty.

    Raises:
      ValueError: if one of the patterns in source matches no files.
    """
    source_paths = source.split(',')
    for source_path in source_paths:
        files = tf.gfile.Glob(
            io_utils.NormalizeToShardedFilePattern(source_path))
        if not files:
            if len(source_paths) > 1:
                raise ValueError(
                    'Cannot find matching files with the pattern "{}" in "{}"'
                    .format(source_path, ','.join(source_paths)))
            raise ValueError(
                'Cannot find matching files with the pattern "{}"'.format(
                    source_path))
        for f in files:
            try:
                # Use the next() builtin rather than the generator's .next()
                # method, which only exists on Python 2.
                return next(io_utils.read_tfrecords(f))
            except StopIteration:
                # StopIteration from next() means this file is empty. Move on
                # to the next file to try to get one example.
                continue
    return None
    def assertDataSetExamplesMatchExpected(self, dataset, expected_dataset):
        """Asserts that loci served by a slim provider match the source file.

        A DatasetDataProvider is built over expected_dataset's slim dataset
        with shuffling disabled, expected_dataset.num_examples loci are
        pulled from it, and they are compared against the 'locus' feature of
        the records read directly from expected_dataset.source.

        Args:
          dataset: not used by this method.
          expected_dataset: dataset object providing get_slim_dataset(),
            source, num_examples and tensor_shape.
        """
        with tf.Session() as sess:
            slim_dataset = expected_dataset.get_slim_dataset()
            tfrecord_options = io_utils.make_tfrecord_options(
                expected_dataset.source)
            provider = slim.dataset_data_provider.DatasetDataProvider(
                slim_dataset,
                shuffle=False,
                reader_kwargs={'options': tfrecord_options})
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            image, label, locus = provider.get(['image', 'label', 'locus'])
            seen = []
            for _ in range(expected_dataset.num_examples):
                # Fetch all three tensors together, but keep only the locus.
                fetched = sess.run([image, label, locus])
                seen.append(fetched[2])
            coord.request_stop()
            coord.join(threads)

        expected_loci = []
        for example in io_utils.read_tfrecords(expected_dataset.source):
            locus_feature = example.features.feature['locus']
            expected_loci.append(locus_feature.bytes_list.value[0])

        self.assertEqual(len(expected_loci), expected_dataset.num_examples)
        self.assertEqual(expected_loci, seen)
        # Note that this expected shape comes from the golden dataset. If the
        # data is remade in the future, the values might need to be modified
        # accordingly.
        self.assertEqual([100, 221, 7], expected_dataset.tensor_shape)
 def test_call_end2end_with_empty_shards(self):
   """Runs call_variants over a sharded input with more shards than records."""
   max_examples = 10
   num_shards = 15
   # Read at most `max_examples` golden examples.
   few_examples = list(
       io_utils.read_tfrecords(
           test_utils.GOLDEN_CALLING_EXAMPLES, max_records=max_examples))
   # With 15 shards and at most 10 examples, multiple shards will be empty.
   sharded_path = test_utils.test_tmpfile('sharded@{}'.format(num_shards))
   io_utils.write_tfrecords(few_examples, sharded_path)
   expected_count = len(few_examples)
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_path,
                                                      expected_count)
Exemplo n.º 7
0
def make_golden_dataset(compressed_inputs=False):
    """Returns a path to the golden postprocess input records.

    Args:
      compressed_inputs: bool. When true, the golden CallVariantsOutput
        records are copied into a '.tfrecord.gz' temp file and that path is
        returned; otherwise the golden path itself is returned.

    Returns:
      str. Path of the TFRecord file to use as input.
    """
    if not compressed_inputs:
        return test_utils.GOLDEN_POSTPROCESS_INPUT

    gz_path = test_utils.test_tmpfile(
        'golden.postprocess_single_site_input.tfrecord.gz')
    golden_records = io_utils.read_tfrecords(
        test_utils.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    io_utils.write_tfrecords(golden_records, gz_path)
    return gz_path
 def test_call_end2end_empty_first_shard(self):
   """Runs call_variants over two shards where the first shard is empty."""
   # Read at most 10 golden examples.
   few_examples = list(
       io_utils.read_tfrecords(
           test_utils.GOLDEN_CALLING_EXAMPLES, max_records=10))
   # Shard 0 gets no records at all; shard 1 gets every example.
   first_shard = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
   io_utils.write_tfrecords([], first_shard)
   second_shard = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
   io_utils.write_tfrecords(few_examples, second_shard)
   sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                      len(few_examples))
 def test_call_variants_with_no_shape(self, model):
   """Checks prepare_inputs rejects an example without an image/shape field."""
   # Start from one good record read from a valid file.
   golden_examples = io_utils.read_tfrecords(
       test_utils.GOLDEN_CALLING_EXAMPLES)
   shapeless_example = next(golden_examples)
   # Strip the image/shape feature so the record becomes invalid.
   del shapeless_example.features.feature['image/shape']
   noshape_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([shapeless_example], noshape_path)
   expected_message = ('Invalid image/shape: we expect to find an image/shape '
                       'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_message):
     call_variants.prepare_inputs(noshape_path, model, batch_size=1)
Exemplo n.º 10
0
    def test_read_tfrecords_max_records(self, filename, max_records):
        """Checks that read_tfrecords yields at most max_records records."""
        protos, path = self.write_test_protos(filename)

        # With max_records=None every written proto should come back;
        # otherwise the count is capped at max_records.
        n_written = len(protos)
        expected_n = (n_written if max_records is None else min(
            max_records, n_written))

        # Create our generator of records from read_tfrecords.
        records = io.read_tfrecords(path,
                                    core_pb2.ContigInfo,
                                    max_records=max_records)
        self.assertLen(list(records), expected_n)
def make_golden_dataset(compressed_inputs=False):
    """Builds a DeepVariantDataSet over the golden training examples.

    Args:
      compressed_inputs: bool. When true, the golden training examples are
        first copied into a '.tfrecord.gz' temp file and the dataset points
        at that copy; otherwise it points at the golden file itself.

    Returns:
      A data_providers.DeepVariantDataSet named 'labeled_golden' with
      num_examples=49.
    """
    source_path = test_utils.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        source_path = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_examples = io_utils.read_tfrecords(
            test_utils.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(golden_examples, source_path)
    return data_providers.DeepVariantDataSet(name='labeled_golden',
                                             source=source_path,
                                             num_examples=49)
 def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                   num_examples):
   """Asserts call_variants with 'random_guess' emits num_examples records.

   Args:
     filename: examples path passed to call_variants as examples_filename.
     num_examples: int. Expected number of CallVariantsOutput records.
   """
   output_path = test_utils.test_tmpfile('call_variants.tfrecord')
   random_guess_model = modeling.get_model('random_guess')
   call_variants.call_variants(
       examples_filename=filename,
       checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
       model=random_guess_model,
       output_file=output_path,
       batch_size=4,
       max_batches=None)
   emitted = list(
       io_utils.read_tfrecords(output_path,
                               deepvariant_pb2.CallVariantsOutput))
   # Check that we have the right number of output protos.
   n_emitted = len(emitted)
   self.assertEqual(n_emitted, num_examples)
Exemplo n.º 13
0
    def test_read_write_tfrecords(self, filename):
        """Round-trips protos through write_test_protos and read_tfrecords."""
        written_protos, path = self.write_test_protos(filename)

        # Create our generator of records from read_tfrecords.
        records = io.read_tfrecords(path, core_pb2.ContigInfo)

        # Make sure it's actually a generator.
        self.assertEqual(type(records), types.GeneratorType)

        # Check the round-trip contents.
        is_sharded = '@' in filename
        if not is_sharded:
            self.assertEqual(written_protos, list(records))
            return
        # Sharded outputs are striped across shards, so order isn't
        # preserved; compare the contents irrespective of order.
        self.assertCountEqual(written_protos, records)
Exemplo n.º 14
0
  def test_make_read_writer_tfrecords(self):
    """Round-trips two reads through a read writer targeting a TFRecord file."""
    tfrecord_path = test_utils.test_tmpfile('test.tfrecord')
    read_writer = genomics_io.make_read_writer(outfile=tfrecord_path)

    # Using the writer in a `with` statement checks that it is a context
    # manager; inside it we write both test reads.
    with read_writer:
      for read in (self.read1, self.read2):
        read_writer.write(read)

    # The output should contain exactly the two reads, in write order.
    expected_reads = [self.read1, self.read2]
    reads_on_disk = list(
        io_utils.read_tfrecords(tfrecord_path, proto=reads_pb2.Read))
    self.assertEqual(expected_reads, reads_on_disk)
    def test_reading_sharded_dataset(self, compressed_inputs):
        """Checks that a sharded copy of the golden dataset reads back intact."""
        golden_dataset = make_golden_dataset(compressed_inputs)
        n_shards = 3
        sharded_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
        # Copy the golden records into the sharded output path.
        golden_records = io_utils.read_tfrecords(golden_dataset.source)
        io_utils.write_tfrecords(golden_records, sharded_path)

        config_file = _test_dataset_config(
            'test_sharded.pbtxt',
            name='sharded_test',
            tfrecord_path=sharded_path,
            num_examples=golden_dataset.num_examples)

        sharded_slim_dataset = data_providers.get_dataset(
            config_file).get_slim_dataset()
        self.assertDataSetExamplesMatchExpected(sharded_slim_dataset,
                                                golden_dataset)
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Checks call_variants raises ValueError for a bad image/format value."""
    # Start from one good record read from a valid file.
    golden_examples = io_utils.read_tfrecords(
        test_utils.GOLDEN_CALLING_EXAMPLES)
    bad_example = next(golden_examples)
    # Overwrite the image/format field to be an invalid value
    # (anything but 'raw').
    format_feature = bad_example.features.feature['image/format']
    format_feature.bytes_list.value[0] = bad_format
    bad_input_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([bad_example], bad_input_path)
    output_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')

    call_kwargs = dict(
        examples_filename=bad_input_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=output_path,
        batch_size=1,
        max_batches=1)
    with self.assertRaises(ValueError):
      call_variants.call_variants(**call_kwargs)
def write_call_variants_output_to_vcf(contigs, input_sorted_tfrecord_path,
                                      output_vcf_path, qual_filter,
                                      multi_allelic_qual_filter, sample_name):
    """Converts sorted CallVariantsOutput protos to Variants in a VCF file.

    CallVariantsOutput records are grouped by the range of their variant;
    each group is merged into one Variant and written to the VCF, with the
    following filters applied: 1) variants are omitted if their quality is
    lower than the `qual_filter` threshold. 2) multi-allelic variants omit
    individual alleles whose qualities are lower than the
    `multi_allelic_qual_filter` threshold.

    Args:
      contigs: list(ContigInfo). A list of the reference genome contigs for
        writers that need contig information.
      input_sorted_tfrecord_path: str. TFRecord format file containing sorted
        CallVariantsOutput protos.
      output_vcf_path: str. Output file in VCF format.
      qual_filter: double. The qual value below which to filter variants.
      multi_allelic_qual_filter: double. The qual value below which to filter
        multi-allelic variants.
      sample_name: str. Sample name to write to VCF file.
    """

    def _variant_range_key(call_variants_output):
        # Grouping key: the range of the variant carried by the record.
        return variantutils.variant_range(call_variants_output.variant)

    logging.info('Writing calls to VCF file: %s', output_vcf_path)
    sync_writer, writer_fn = genomics_io.make_variant_writer(
        output_vcf_path, contigs, samples=[sample_name], filters=FILTERS)
    with sync_writer, io_utils.AsyncWriter(writer_fn) as writer:
        sorted_outputs = io_utils.read_tfrecords(
            input_sorted_tfrecord_path,
            proto=deepvariant_pb2.CallVariantsOutput)
        # groupby only merges adjacent records, hence the sorted input.
        for _, same_range_outputs in itertools.groupby(sorted_outputs,
                                                       _variant_range_key):
            canonical_variant, predictions = merge_predictions(
                list(same_range_outputs), multi_allelic_qual_filter)
            merged_variant = add_call_to_variant(canonical_variant,
                                                 predictions,
                                                 qual_filter=qual_filter,
                                                 sample_name=sample_name)
            writer.write(merged_variant)
 def setUpClass(cls):
   """Loads the golden calling examples and model shared by the tests."""
   golden_examples = list(
       io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))
   cls.examples = golden_examples
   # One variant per example, in the same order as cls.examples.
   cls.variants = list(map(tf_utils.example_variant, golden_examples))
   cls.model = modeling.get_model('random_guess')
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """Runs call_variants end to end and checks the CallVariantsOutputs.

    Args:
      model: the model to run; its `name` decides how many batches are run.
      shard_inputs: bool. When true, the golden examples are rewritten into
        3 shards and the sharded path is used as input.
      include_debug_info: bool. Value assigned to FLAGS.include_debug_info;
        controls which debug_info checks are made on the outputs.
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = test_utils.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs), batch_size * max_batches
        if max_batches else len(examples))

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants (max_batches
    # is None), since we processed all of the examples with that model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variantutils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variantutils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp, variantutils.is_snp(
            cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(
                  example) == call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though currently
      # as implemented we find our example using this information so it cannot
      # fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1. The all()
      # is required: a bare generator expression is always truthy, which
      # would make this assertion pass unconditionally.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))
Exemplo n.º 20
0
    def test_realigner_diagnostics(self, enabled, emit_reads):
        """Checks which diagnostic outputs the realigner writes.

        Args:
          enabled: bool. Value assigned to config.diagnostics.enabled; when
            false no diagnostics directory may be created.
          emit_reads: bool. Value assigned to
            config.diagnostics.emit_realigned_reads; controls whether the
            realigned reads file is expected to exist.
        """
        dx_dir = test_utils.test_tmpfile('dx')
        region_str = 'chr20:10046179-10046188'
        region = ranges.parse_literal(region_str)
        assembled_region_str = 'chr20:10046109-10046257'
        reads = _get_reads(region)
        self.config = realigner.realigner_config(FLAGS)
        self.config.diagnostics.enabled = enabled
        self.config.diagnostics.output_root = dx_dir
        self.config.diagnostics.emit_realigned_reads = emit_reads
        self.reads_realigner = realigner.Realigner(self.config,
                                                   self.ref_reader)
        _, realigned_reads = self.reads_realigner.realign_reads(reads, region)
        # Force close all resources.
        self.reads_realigner.diagnostic_logger.close()

        if not enabled:
            # Make sure our diagnostic output isn't emitted.
            self.assertFalse(tf.gfile.Exists(dx_dir))
        else:
            # Our root directory exists.
            self.assertTrue(tf.gfile.IsDirectory(dx_dir))

            # We expect a realigner_metrics.csv in our rootdir with 1 entry in it.
            metrics_file = os.path.join(
                dx_dir,
                self.reads_realigner.diagnostic_logger.metrics_filename)
            self.assertTrue(tf.gfile.Exists(metrics_file))
            with tf.gfile.FastGFile(metrics_file) as fin:
                rows = list(csv.DictReader(fin))
                self.assertEqual(len(rows), 1)
                self.assertEqual(set(rows[0].keys()),
                                 {'window', 'k', 'n_haplotypes', 'time'})
                self.assertEqual(rows[0]['window'], assembled_region_str)
                self.assertEqual(int(rows[0]['k']), 25)
                # assertEqual, not assertTrue: assertTrue(x, 2) would treat
                # 2 as the failure message and never compare the count.
                self.assertEqual(int(rows[0]['n_haplotypes']), 2)
                # Check that our runtime is reasonable (greater than 0, less than 10 s).
                self.assertTrue(0.0 < float(rows[0]['time']) < 10.0)

            # As does the subdirectory for this region.
            region_subdir = os.path.join(dx_dir, assembled_region_str)
            self.assertTrue(tf.gfile.IsDirectory(region_subdir))

            # We always have a graph.dot
            self.assertTrue(
                tf.gfile.Exists(
                    os.path.join(
                        region_subdir, self.reads_realigner.diagnostic_logger.
                        graph_filename)))

            reads_file = os.path.join(
                dx_dir, region_str, self.reads_realigner.diagnostic_logger.
                realigned_reads_filename)
            if emit_reads:
                self.assertTrue(tf.gfile.Exists(reads_file))
                reads_from_dx = io_utils.read_tfrecords(
                    reads_file, reads_pb2.Read)
                self.assertCountEqual(reads_from_dx, realigned_reads)
            else:
                self.assertFalse(tf.gfile.Exists(reads_file))
Exemplo n.º 21
0
def main(argv=()):
    """Writes a VCF, and optionally a gVCF, from CallVariantsOutput records.

    Args:
      argv: command-line arguments; anything beyond the program name is
        reported as a CommandLineError.
    """
    with errors.clean_commandline_error_exit():
        has_positional_args = len(argv) > 1
        if has_positional_args:
            message = (
                'Command line parsing failure: postprocess_variants does not accept '
                'positional arguments but some are present on the command line: '
                '"{}".').format(str(argv))
            errors.log_and_raise(message, errors.CommandLineError)
        del argv  # Unused.

        # The two gVCF flags must be given together or not at all.
        gvcf_input_given = bool(FLAGS.nonvariant_site_tfrecord_path)
        gvcf_output_given = bool(FLAGS.gvcf_outfile)
        if gvcf_input_given != gvcf_output_given:
            errors.log_and_raise(
                'gVCF creation requires both nonvariant_site_tfrecord_path and '
                'gvcf_outfile flags to be set.', errors.CommandLineError)

        proto_utils.uses_fast_cpp_protos_or_die()

        logging_level.set_from_flag()

        with genomics_io.make_ref_reader(FLAGS.ref) as ref_reader:
            contigs = ref_reader.contigs
        paths = io_utils.maybe_generate_sharded_filenames(FLAGS.infile)
        with tempfile.NamedTemporaryFile() as sorted_tmp:
            # The temp file is handed to process_single_sites_tfrecords and
            # then read back below as the sorted TFRecord input.
            postprocess_variants_lib.process_single_sites_tfrecords(
                contigs, paths, sorted_tmp.name)
            # Read one CallVariantsOutput record and extract the sample name
            # from it. This assumes that all CallVariantsOutput protos in the
            # infile contain a single VariantCall within their constituent
            # Variant proto, and that the call_set_name is identical in each
            # of the records.
            first_record_reader = io_utils.read_tfrecords(
                paths[0],
                proto=deepvariant_pb2.CallVariantsOutput,
                max_records=1)
            first_record = next(first_record_reader)
            sample_name = _extract_single_sample_name(first_record)
            independent_variants = _transform_call_variants_output_to_variants(
                input_sorted_tfrecord_path=sorted_tmp.name,
                qual_filter=FLAGS.qual_filter,
                multi_allelic_qual_filter=FLAGS.multi_allelic_qual_filter,
                sample_name=sample_name)
            resolved_variants = haplotypes.maybe_resolve_conflicting_variants(
                independent_variants)
            write_variants_to_vcf(contigs=contigs,
                                  variant_generator=resolved_variants,
                                  output_vcf_path=FLAGS.outfile,
                                  sample_name=sample_name)

        # Also write out the gVCF file if it was provided.
        if FLAGS.nonvariant_site_tfrecord_path:
            sort_keyfn = _get_contig_based_variant_sort_keyfn(contigs)
            nonvariant_generator = io_utils.read_shard_sorted_tfrecords(
                FLAGS.nonvariant_site_tfrecord_path,
                key=sort_keyfn,
                proto=variants_pb2.Variant)
            # Re-read the VCF written above so its variants can be merged
            # with the non-variant sites.
            with genomics_io.make_vcf_reader(
                    FLAGS.outfile, use_index=False,
                    include_likelihoods=True) as vcf_reader:
                lessthanfn = _get_contig_based_lessthan(vcf_reader.contigs)
                gvcf_variants = (_transform_to_gvcf_record(called_variant)
                                 for called_variant in vcf_reader.iterate())
                merged_variants = merge_variants_and_nonvariants(
                    gvcf_variants, nonvariant_generator, lessthanfn)
                write_variants_to_vcf(contigs=contigs,
                                      variant_generator=merged_variants,
                                      output_vcf_path=FLAGS.gvcf_outfile,
                                      sample_name=sample_name,
                                      filters=FILTERS)
Exemplo n.º 22
0
  def test_writing_canned_variants(self):
    """Tests writing all the variants that are 'canned' in our tfrecord file.

    The first five Variant protos from the golden TFRecord are written with a
    VcfWriter, and the resulting file is compared line by line against the
    expected header and records.
    """

    # This file is in the TF record format
    tfrecord_file = test_utils.genomics_core_testdata(
        'test_samples.vcf.golden.tfrecord')

    # Contigs, sample name and filters given here are expected to show up in
    # the header lines checked at the end of the test.
    writer_options = core_pb2.VcfWriterOptions(
        contigs=[
            core_pb2.ContigInfo(name='chr1', n_bases=248956422),
            core_pb2.ContigInfo(name='chr2', n_bases=242193529),
            core_pb2.ContigInfo(name='chr3', n_bases=198295559),
            core_pb2.ContigInfo(name='chrX', n_bases=156040895)
        ],
        sample_names=['NA12878_18_99'],
        filters=[
            core_pb2.VcfFilterInfo(id='LowQual'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL95.00to96.00'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL96.00to97.00'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL97.00to99.00'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.00to99.50'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.50to99.90'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.90to99.95'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.95to100.00+'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.95to100.00'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.50to99.60'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.60to99.80'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.80to99.90'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.90to99.95'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.95to100.00+'),
            core_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.95to100.00'),
        ])

    variant_records = list(
        io_utils.read_tfrecords(tfrecord_file, proto=variants_pb2.Variant))
    out_fname = test_utils.test_tmpfile('output.vcf')
    # Only the first five records are written; they correspond to the five
    # data lines at the end of expected_vcf_content.
    with vcf_writer.VcfWriter.to_file(out_fname, writer_options) as writer:
      for record in variant_records[:5]:
        writer.write(record)

    # Check: are the variants written as expected?
    # pylint: disable=line-too-long
    expected_vcf_content = [
        '##fileformat=VCFv4.2\n',
        '##FILTER=<ID=PASS,Description="All filters passed">\n',
        '##FILTER=<ID=LowQual,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
        '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
        '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
        '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
        '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
        '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
        'passing filters reads.">\n',
        '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
        'passing filters reads for each allele.">\n',
        '##FORMAT=<ID=VAF,Number=A,Type=Float,Description=\"Variant allele '
        'fractions.">\n',
        '##FORMAT=<ID=GL,Number=G,Type=Float,Description="Genotype '
        'likelihoods, log10 encoded">\n',
        '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
        'likelihoods, Phred encoded">\n',
        '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
        'the interval">\n', '##contig=<ID=chr1,length=248956422>\n',
        '##contig=<ID=chr2,length=242193529>\n',
        '##contig=<ID=chr3,length=198295559>\n',
        '##contig=<ID=chrX,length=156040895>\n',
        '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
        'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
        'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
        'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
        'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
        'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
    ]
    # pylint: enable=line-too-long

    # The comparison is exact and line-by-line, so header order matters.
    with tf.gfile.GFile(out_fname, 'r') as f:
      self.assertEqual(f.readlines(), expected_vcf_content)
    def test_make_examples_end2end(self, mode, num_shards):
        """Runs make_examples end to end and validates everything it writes.

        Args:
          mode: str. Either 'calling' or 'training'; selects which flags are
            set and which golden file the examples are compared against.
          num_shards: int. Number of output shards; one make_examples task is
            run per shard (at least one task is always run).
        """
        self.assertIn(mode, {'calling', 'training'})
        region = ranges.parse_literal('chr20:10,000,000-10,010,000')
        FLAGS.ref = test_utils.CHR20_FASTA
        FLAGS.reads = test_utils.CHR20_BAM
        FLAGS.candidates = test_utils.test_tmpfile(
            _sharded('vsc.tfrecord', num_shards))
        FLAGS.examples = test_utils.test_tmpfile(
            _sharded('examples.tfrecord', num_shards))
        FLAGS.regions = [ranges.to_literal(region)]
        FLAGS.partition_size = 1000
        FLAGS.mode = mode

        if mode == 'calling':
            FLAGS.gvcf = test_utils.test_tmpfile(
                _sharded('gvcf.tfrecord', num_shards))
        else:
            FLAGS.truth_variants = test_utils.TRUTH_VARIANTS_VCF
            FLAGS.confident_regions = test_utils.CONFIDENT_REGIONS_BED

        # `options` from the last task is reused by the checks below.
        for task_id in range(max(num_shards, 1)):
            FLAGS.task = task_id
            options = make_examples.default_options(add_flags=True)
            make_examples.make_examples_runner(options)

        # Test that our candidates are reasonable, calling specific helper functions
        # to check lots of properties of the output.
        candidates = _sort_candidates(
            io_utils.read_tfrecords(FLAGS.candidates,
                                    proto=deepvariant_pb2.DeepVariantCall))
        self.verify_deepvariant_calls(candidates, options)
        self.verify_variants([call.variant for call in candidates],
                             region,
                             options,
                             is_gvcf=False)

        # Verify that the variants in the examples are all good.
        examples = self.verify_examples(FLAGS.examples,
                                        region,
                                        options,
                                        verify_labels=mode == 'training')
        example_variants = [tf_utils.example_variant(ex) for ex in examples]
        self.verify_variants(example_variants, region, options, is_gvcf=False)

        # Verify the integrity of the examples and then check that they match our
        # golden labeled examples. Note we expect the order for both training and
        # calling modes to produce deterministic order because we fix the random
        # seed.
        if mode == 'calling':
            golden_file = _sharded(test_utils.GOLDEN_CALLING_EXAMPLES,
                                   num_shards)
        else:
            golden_file = _sharded(test_utils.GOLDEN_TRAINING_EXAMPLES,
                                   num_shards)
        self.assertDeepVariantExamplesEqual(
            examples, list(io_utils.read_tfrecords(golden_file)))

        if mode == 'calling':
            # Use the reader as a context manager so that it is closed once
            # the truth variants have been read.
            with genomics_io.make_vcf_reader(
                    test_utils.TRUTH_VARIANTS_VCF) as nist_reader:
                nist_variants = list(nist_reader.query(region))
            self.verify_nist_concordance(example_variants, nist_variants)

            # Check the quality of our generated gvcf file.
            gvcfs = _sort_variants(
                io_utils.read_tfrecords(FLAGS.gvcf,
                                        proto=variants_pb2.Variant))
            self.verify_variants(gvcfs, region, options, is_gvcf=True)
            self.verify_contiguity(gvcfs, region)