Ejemplo n.º 1
0
 def testGetNoneShapeFromEmptyExamplesPath(self, file_name_to_write,
                                           tfrecord_path_to_match):
     """Checks that an examples file holding no records yields a None shape.

     Args:
       file_name_to_write: Name of the temp file the empty record list is
         written to.
       tfrecord_path_to_match: Name (resolved under the test tmp dir) that is
         handed to get_shape_from_examples_path.
     """
     empty_records_path = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([], empty_records_path)

     lookup_path = test_utils.test_tmpfile(tfrecord_path_to_match)
     observed_shape = tf_utils.get_shape_from_examples_path(lookup_path)
     self.assertIsNone(observed_shape)
Ejemplo n.º 2
0
 def testGetShapeFromExamplesPath(self, file_name_to_write,
                                  tfrecord_path_to_match):
     """Checks that the shape stored in an example is read back from disk.

     Writes a single Example whose 'image/shape' feature is [1, 2, 3] and
     verifies that get_shape_from_examples_path reports that same shape.

     Args:
       file_name_to_write: Name of the temp file the example is written to.
       tfrecord_path_to_match: Name (resolved under the test tmp dir) that is
         handed to get_shape_from_examples_path.
     """
     example = example_pb2.Example()
     valid_shape = [1, 2, 3]
     example.features.feature['image/shape'].int64_list.value.extend(
         valid_shape)
     output_file = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([example], output_file)
     observed_shape = tf_utils.get_shape_from_examples_path(
         test_utils.test_tmpfile(tfrecord_path_to_match))
     # Compare against the shape that was written; without this check the test
     # would only prove that the call does not raise. list() normalizes the
     # result in case it is a tuple or a protobuf repeated container.
     self.assertEqual(valid_shape, list(observed_shape))
 def test_call_end2end_empty_first_shard(self):
   """Runs call_variants on a two-shard input whose first shard is empty."""
   # Only the first 10 golden examples are needed.
   golden_reader = io_utils.read_tfrecords(
       test_utils.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_reader)

   # Shard 0 receives no records; shard 1 receives every example.
   shard_contents = (
       ('empty_1st_shard-00000-of-00002', []),
       ('empty_1st_shard-00001-of-00002', examples),
   )
   for shard_name, records in shard_contents:
     shard_path = test_utils.test_tmpfile(shard_name)
     io_utils.write_tfrecords(records, shard_path)

   sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                      len(examples))
 def test_get_shape_from_examples_path(self, file_name_to_write,
                                       tfrecord_path_to_match):
     """Checks that a DeepVariantDataSet exposes the shape stored in its source.

     Args:
       file_name_to_write: Name of the temp file the single example is written
         to.
       tfrecord_path_to_match: Name (resolved under the test tmp dir) used as
         the dataset's `source`.
     """
     expected_shape = [1, 2, 3]
     record = example_pb2.Example()
     shape_feature = record.features.feature['image/shape']
     shape_feature.int64_list.value.extend(expected_shape)

     written_path = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([record], written_path)

     dataset = data_providers.DeepVariantDataSet(
         name='test_shape',
         source=test_utils.test_tmpfile(tfrecord_path_to_match),
         num_examples=1)
     self.assertEqual(expected_shape, dataset.tensor_shape)
Ejemplo n.º 5
0
 def test_call_variants_runs_on_gpus(self, model):
     """Invokes call_variants with execution_hardware='accelerator'.

     The test makes no assertions; it passes as long as the call completes
     without raising.

     Args:
       model: Model object forwarded to call_variants.
     """
     run_kwargs = dict(
         examples_filename=test_utils.GOLDEN_CALLING_EXAMPLES,
         checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
         model=model,
         execution_hardware='accelerator',
         output_file=test_utils.test_tmpfile('zzz.tfrecord'))
     call_variants.call_variants(**run_kwargs)
Ejemplo n.º 6
0
 def test_call_variants_with_empty_input(self):
     """Checks that prepare_inputs does not crash on a file with no records."""
     empty_path = test_utils.test_tmpfile('empty.tfrecord')
     io_utils.write_tfrecords([], empty_path)

     # No assertion: the call merely has to return without raising.
     random_guess_model = modeling.get_model('random_guess')
     call_variants.prepare_inputs(
         empty_path, random_guess_model, batch_size=1)
def _test_dataset_config(filename, **kwargs):
    """Builds a DeepVariantDatasetConfig from kwargs and writes it as pbtxt.

    Args:
      filename: Name of the file to create under the test tmp dir.
      **kwargs: Field values forwarded to DeepVariantDatasetConfig.

    Returns:
      The path the config was written to.
    """
    pbtxt_path = test_utils.test_tmpfile(filename)
    config_proto = deepvariant_pb2.DeepVariantDatasetConfig(**kwargs)
    data_providers.write_dataset_config_to_pbtxt(config_proto, pbtxt_path)
    return pbtxt_path
Ejemplo n.º 8
0
    def testModelShapes(self):
        """Saves a two-variable checkpoint and checks model_shapes on it."""
        # Build a graph with a (2, 3) variable and a (3, 2, 1) variable.
        v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name='v0')
        v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]],
                         dtype=tf.float32,
                         name='v1')
        init_op = tf.initialize_all_variables()
        saver = tf.train.Saver({'v0': v0, 'v1': v1})
        checkpoint_path = test_utils.test_tmpfile('ckpt_for_debug_string')

        with tf.Session() as session:
            session.run(init_op)
            # Write the checkpoint that model_shapes will inspect.
            saver.save(session, checkpoint_path)

            # With no variable names requested, every variable is reported.
            all_shapes = tf_utils.model_shapes(checkpoint_path)
            self.assertEqual({'v0': (2, 3), 'v1': (3, 2, 1)}, all_shapes)

            # Requesting a single variable reports only that variable's shape.
            single_requests = (
                ('v0', (2, 3)),
                ('v1', (3, 2, 1)),
            )
            for var_name, var_shape in single_requests:
                self.assertEqual(
                    {var_name: var_shape},
                    tf_utils.model_shapes(checkpoint_path, [var_name]))

            # A variable missing from the checkpoint raises a KeyError.
            with self.assertRaisesRegexp(KeyError, 'v3'):
                tf_utils.model_shapes(checkpoint_path, ['v3'])
  def test_call_end2end(self, compressed_inputs):
    """Runs postprocess_variants.main and diffs its outputs against goldens.

    Both the VCF written to FLAGS.outfile and the gVCF written to
    FLAGS.gvcf_outfile are compared line by line with their golden files.

    Args:
      compressed_inputs: Forwarded to make_golden_dataset to choose the input
        file assigned to FLAGS.infile.
    """
    FLAGS.infile = make_golden_dataset(compressed_inputs)
    FLAGS.ref = test_utils.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('calls.vcf')
    FLAGS.nonvariant_site_tfrecord_path = (
        test_utils.GOLDEN_POSTPROCESS_GVCF_INPUT)
    FLAGS.gvcf_outfile = test_utils.test_tmpfile('gvcf_calls.vcf')

    postprocess_variants.main(['postprocess_variants.py'])

    def read_lines(path):
      # Use a context manager so each handle is closed right after reading
      # instead of lingering until garbage collection.
      with tf.gfile.FastGFile(path) as handle:
        return handle.readlines()

    self.assertEqual(
        read_lines(FLAGS.outfile),
        read_lines(test_utils.GOLDEN_POSTPROCESS_OUTPUT))
    self.assertEqual(
        read_lines(FLAGS.gvcf_outfile),
        read_lines(test_utils.GOLDEN_POSTPROCESS_GVCF_OUTPUT))
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Checks that call_variants raises ValueError for a bad image/format.

    Args:
      model: Model object forwarded to call_variants.
      bad_format: Value written into the example's image/format feature.
    """
    # Start from one known-good golden record.
    golden_reader = io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES)
    corrupted = next(golden_reader)
    # Replace image/format with an invalid value (anything but 'raw').
    format_values = corrupted.features.feature['image/format'].bytes_list.value
    format_values[0] = bad_format

    bad_input_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([corrupted], bad_input_path)
    output_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=bad_input_path,
          checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
          model=model,
          output_file=output_path,
          batch_size=1,
          max_batches=1)
 def _run():
   """Runs call_variants for one batch of one golden example.

   `self` and `hardware_env` are not parameters; they are free variables that
   must be supplied by the enclosing scope.
   """
   run_kwargs = dict(
       examples_filename=test_utils.GOLDEN_CALLING_EXAMPLES,
       checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
       model=self.model,
       execution_hardware=hardware_env,
       max_batches=1,
       batch_size=1,
       output_file=test_utils.test_tmpfile('zzz.tfrecord'))
   call_variants.call_variants(**run_kwargs)
 def test_call_end2end_with_empty_shards(self):
   """Runs call_variants on a sharded input that has more shards than records."""
   # At most 10 examples are read from the golden file.
   golden_reader = io_utils.read_tfrecords(
       test_utils.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_reader)

   # 15 shards for at most 10 examples, so several shards must stay empty.
   n_shards = 15
   sharded_spec = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
   io_utils.write_tfrecords(examples, sharded_spec)

   expected_count = len(examples)
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                      expected_count)
 def test_call_variants_with_no_shape(self, model):
   """Checks that prepare_inputs rejects an example lacking image/shape.

   Args:
     model: Model object forwarded to prepare_inputs.
   """
   # Take one valid golden record and strip its image/shape feature.
   golden_reader = io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES)
   shapeless = next(golden_reader)
   del shapeless.features.feature['image/shape']

   shapeless_path = test_utils.test_tmpfile(
       'make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([shapeless], shapeless_path)

   expected_error = ('Invalid image/shape: we expect to find an image/shape '
                     'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_error):
     call_variants.prepare_inputs(shapeless_path, model, batch_size=1)
    def test_catches_bad_flags(self):
        """Checks that training mode without confident_regions logs and exits."""
        # Populate every flag the run needs with a valid value...
        target_region = ranges.parse_literal('chr20:10,000,000-10,010,000')
        FLAGS.ref = test_utils.CHR20_FASTA
        FLAGS.reads = test_utils.CHR20_BAM
        FLAGS.candidates = test_utils.test_tmpfile('vsc.tfrecord')
        FLAGS.examples = test_utils.test_tmpfile('examples.tfrecord')
        FLAGS.regions = [ranges.to_literal(target_region)]
        FLAGS.partition_size = 1000
        FLAGS.mode = 'training'
        FLAGS.truth_variants = test_utils.TRUTH_VARIANTS_VCF
        # ...except this one, which is deliberately left empty.
        FLAGS.confident_regions = ''

        # Intercept logging.error and sys.exit so the failure can be inspected
        # without terminating the test process.
        with mock.patch.object(logging, 'error') as error_log_mock:
            with mock.patch.object(sys, 'exit') as exit_mock:
                make_examples.main(['make_examples.py'])

        expected_message = (
            'confident_regions is required when in training mode.')
        error_log_mock.assert_called_once_with(expected_message)
        exit_mock.assert_called_once_with(errno.ENOENT)
Ejemplo n.º 15
0
def make_golden_dataset(compressed_inputs=False):
    """Returns the path of the golden postprocess input to use in a test.

    Args:
      compressed_inputs: If falsy, the golden input path is returned as is.
        If truthy, the golden CallVariantsOutput records are copied to a
        '.tfrecord.gz' file under the test tmp dir and that path is returned.

    Returns:
      The path of the records file.
    """
    if not compressed_inputs:
        return test_utils.GOLDEN_POSTPROCESS_INPUT

    gz_path = test_utils.test_tmpfile(
        'golden.postprocess_single_site_input.tfrecord.gz')
    golden_records = io_utils.read_tfrecords(
        test_utils.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    io_utils.write_tfrecords(golden_records, gz_path)
    return gz_path
def make_golden_dataset(compressed_inputs=False):
    """Builds a DeepVariantDataSet over the golden training examples.

    Args:
      compressed_inputs: If truthy, the golden training examples are first
        copied to a '.tfrecord.gz' file under the test tmp dir and the dataset
        reads from that copy; otherwise it reads the golden file directly.

    Returns:
      A DeepVariantDataSet named 'labeled_golden' with num_examples=49.
    """
    golden_path = test_utils.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        gz_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
        io_utils.write_tfrecords(
            io_utils.read_tfrecords(test_utils.GOLDEN_TRAINING_EXAMPLES),
            gz_path)
        dataset_source = gz_path
    else:
        dataset_source = golden_path
    return data_providers.DeepVariantDataSet(
        name='labeled_golden', source=dataset_source, num_examples=49)
 def test_call_variants_non_accelerated_execution_runs(self,
                                                       execution_hardware):
   """Runs one single-example batch on the given execution hardware.

   Args:
     execution_hardware: Value forwarded to call_variants' execution_hardware.
   """
   # This doesn't mock out the list_devices call so it's worth keeping
   # despite being very similar to the parameterized test below.
   output_path = test_utils.test_tmpfile('call_variants_cpu_only.tfrecord')
   run_kwargs = dict(
       examples_filename=test_utils.GOLDEN_CALLING_EXAMPLES,
       checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
       model=self.model,
       execution_hardware=execution_hardware,
       max_batches=1,
       batch_size=1,
       output_file=output_path)
   call_variants.call_variants(**run_kwargs)
Ejemplo n.º 18
0
 def test_catches_bad_argv(self):
     """Checks that an extra positional argument is logged and triggers exit."""
     # All flags are valid, so any failure must come from argv itself.
     FLAGS.infile = make_golden_dataset(False)
     FLAGS.ref = test_utils.CHR20_FASTA
     FLAGS.outfile = test_utils.test_tmpfile('nonempty_outfile.vcf')

     argv = ['postprocess_variants.py', 'extra_arg']
     # Intercept logging.error and sys.exit so the failure can be inspected
     # without terminating the test process.
     with mock.patch.object(logging, 'error') as error_log_mock:
         with mock.patch.object(sys, 'exit') as exit_mock:
             postprocess_variants.main(argv)

     expected_message = (
         'Command line parsing failure: postprocess_variants does not accept '
         'positional arguments but some are present on the command line: '
         '"[\'postprocess_variants.py\', \'extra_arg\']".')
     error_log_mock.assert_called_once_with(expected_message)
     exit_mock.assert_called_once_with(errno.ENOENT)
 def assertCallVariantsEmitsNRecordsForRandomGuess(self, filename,
                                                   num_examples):
   """Asserts call_variants with 'random_guess' writes num_examples records.

   Args:
     filename: Examples path forwarded to call_variants.
     num_examples: Number of CallVariantsOutput records expected in the output.
   """
   output_path = test_utils.test_tmpfile('call_variants.tfrecord')
   random_guess_model = modeling.get_model('random_guess')
   # max_batches=None: no batch limit is passed.
   call_variants.call_variants(
       examples_filename=filename,
       checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
       model=random_guess_model,
       output_file=output_path,
       batch_size=4,
       max_batches=None)

   output_reader = io_utils.read_tfrecords(
       output_path, deepvariant_pb2.CallVariantsOutput)
   emitted = list(output_reader)
   # The number of output protos must match the expectation.
   self.assertEqual(len(emitted), num_examples)
    def test_reading_sharded_dataset(self, compressed_inputs):
        """Checks that a 3-shard copy of the golden dataset reads back equal.

        Args:
          compressed_inputs: Forwarded to make_golden_dataset.
        """
        reference = make_golden_dataset(compressed_inputs)

        # Re-write the golden records across three shards.
        n_shards = 3
        sharded_spec = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
        golden_records = io_utils.read_tfrecords(reference.source)
        io_utils.write_tfrecords(golden_records, sharded_spec)

        # Describe the sharded copy in a dataset config file.
        config_path = _test_dataset_config(
            'test_sharded.pbtxt',
            name='sharded_test',
            tfrecord_path=sharded_spec,
            num_examples=reference.num_examples)

        slim_dataset = data_providers.get_dataset(
            config_path).get_slim_dataset()
        self.assertDataSetExamplesMatchExpected(slim_dataset, reference)
Ejemplo n.º 21
0
 def _run_tiny_training(self, model_name, dataset):
     """Trains model_name for one step on dataset and expects a checkpoint.

     Args:
       model_name: Assigned to FLAGS.model_name and used to name the train
         directory.
       dataset: Object returned by the patched data_providers.get_dataset.
     """
     patch_target = 'deepvariant.data_providers.get_dataset'
     with mock.patch(patch_target) as get_dataset_mock:
         get_dataset_mock.return_value = dataset

         # Configure a minimal run: a single step with a batch of two.
         FLAGS.train_dir = test_utils.test_tmpfile(model_name)
         FLAGS.batch_size = 2
         FLAGS.model_name = model_name
         FLAGS.save_interval_secs = 0
         FLAGS.number_of_steps = 1
         FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
         FLAGS.start_from_checkpoint = ''

         model_train.parse_and_run()

         # The dataset must have been requested exactly once, by config path.
         get_dataset_mock.assert_called_once_with(FLAGS.dataset_config_pbtxt)
         # A checkpoint must exist in the train directory after training.
         newest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
         self.assertIsNotNone(newest_checkpoint)
  def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Checks that prepare_inputs yields every variant written to the input.

    Args:
      filename: Name of the temp file (or sharded spec) to write examples to.
      expand_to_file_pattern: If truthy, the path is converted with
        NormalizeToShardedFilePattern before being given to prepare_inputs.
    """
    examples_path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, examples_path)
    if expand_to_file_pattern:
      # Transform foo@3 to foo-?????-of-00003.
      examples_path = io_utils.NormalizeToShardedFilePattern(examples_path)

    with self.test_session() as session:
      prepared = call_variants.prepare_inputs(
          examples_path, self.model, batch_size=1)
      _, variants_op, _ = prepared
      for make_initializer in (tf.local_variables_initializer,
                               tf.global_variables_initializer):
        session.run(make_initializer())

      # Drain the input until TensorFlow reports that it is exhausted.
      collected = []
      exhausted = False
      while not exhausted:
        try:
          collected.extend(session.run(variants_op))
        except tf.errors.OutOfRangeError:
          exhausted = True

      decoded = variantutils.decode_variants(collected)
      self.assertItemsEqual(self.variants, decoded)
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """Runs call_variants end to end and validates the emitted outputs.

    Args:
      model: Model forwarded to call_variants. Only the model named
        'random_guess' is run over all examples; any other model is limited
        to a single batch.
      shard_inputs: If truthy, the golden examples are re-written to a
        'sharded@3' path and read from there instead of the golden file.
      include_debug_info: Assigned to FLAGS.include_debug_info; also selects
        which debug_info checks are applied to the outputs.
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = test_utils.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    expected_count = (
        batch_size * max_batches if max_batches else len(examples))
    self.assertEqual(len(call_variants_outputs), expected_count)

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants (max_batches
    # is None), since we processed all of the examples with that model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variantutils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variantutils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp, variantutils.is_snp(
            cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(
                  example) == call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though currently
      # as implemented we find our example using this information so it cannot
      # fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1. all() is
      # required: a bare generator expression is always truthy, so passing it
      # straight to assertTrue would never fail.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))
Ejemplo n.º 24
0
    def test_realigner_diagnostics(self, enabled, emit_reads):
        """Checks which diagnostic files the realigner writes.

        Realigns the reads of one fixed region with diagnostics configured by
        the arguments, then inspects the diagnostics output directory.

        Args:
          enabled: Assigned to config.diagnostics.enabled. When falsy, the
            output directory is expected not to exist at all.
          emit_reads: Assigned to config.diagnostics.emit_realigned_reads.
            Determines whether the realigned-reads file is expected to exist.
        """
        # Root directory under which any diagnostic output would be written.
        dx_dir = test_utils.test_tmpfile('dx')
        region_str = 'chr20:10046179-10046188'
        region = ranges.parse_literal(region_str)
        # Window name expected in the metrics CSV and as the per-region
        # subdirectory name.
        assembled_region_str = 'chr20:10046109-10046257'
        reads = _get_reads(region)
        self.config = realigner.realigner_config(FLAGS)
        self.config.diagnostics.enabled = enabled
        self.config.diagnostics.output_root = dx_dir
        self.config.diagnostics.emit_realigned_reads = emit_reads
        self.reads_realigner = realigner.Realigner(self.config,
                                                   self.ref_reader)
        _, realigned_reads = self.reads_realigner.realign_reads(reads, region)
        self.reads_realigner.diagnostic_logger.close(
        )  # Force close all resources.

        if not enabled:
            # Make sure our diagnostic output isn't emitted.
            self.assertFalse(tf.gfile.Exists(dx_dir))
        else:
            # Our root directory exists.
            self.assertTrue(tf.gfile.IsDirectory(dx_dir))

            # We expect a realigner_metrics.csv in our rootdir with 1 entry in it.
            metrics_file = os.path.join(
                dx_dir,
                self.reads_realigner.diagnostic_logger.metrics_filename)
            self.assertTrue(tf.gfile.Exists(metrics_file))
            with tf.gfile.FastGFile(metrics_file) as fin:
                rows = list(csv.DictReader(fin))
                self.assertEqual(len(rows), 1)
                self.assertEqual(set(rows[0].keys()),
                                 {'window', 'k', 'n_haplotypes', 'time'})
                self.assertEqual(rows[0]['window'], assembled_region_str)
                self.assertEqual(int(rows[0]['k']), 25)
                # NOTE(review): assertTrue treats its second argument as the
                # failure message, so this only checks that n_haplotypes is
                # non-zero. If the intent was to require exactly 2, this
                # should be assertEqual -- confirm the expected value.
                self.assertTrue(int(rows[0]['n_haplotypes']), 2)
                # Check that our runtime is reasonable (greater than 0, less than 10 s).
                self.assertTrue(0.0 < float(rows[0]['time']) < 10.0)

            # As does the subdirectory for this region.
            region_subdir = os.path.join(dx_dir, assembled_region_str)
            self.assertTrue(tf.gfile.IsDirectory(region_subdir))

            # We always have a graph.dot
            self.assertTrue(
                tf.gfile.Exists(
                    os.path.join(
                        region_subdir, self.reads_realigner.diagnostic_logger.
                        graph_filename)))

            # NOTE(review): this path is built from region_str, whereas the
            # graph file above lives under assembled_region_str -- confirm
            # that the realigned reads really are written under the original
            # region's name.
            reads_file = os.path.join(
                dx_dir, region_str, self.reads_realigner.diagnostic_logger.
                realigned_reads_filename)
            if emit_reads:
                # The emitted reads must match what realign_reads returned,
                # irrespective of order.
                self.assertTrue(tf.gfile.Exists(reads_file))
                reads_from_dx = io_utils.read_tfrecords(
                    reads_file, reads_pb2.Read)
                self.assertCountEqual(reads_from_dx, realigned_reads)
            else:
                self.assertFalse(tf.gfile.Exists(reads_file))
    def test_make_examples_end2end(self, mode, num_shards):
        """Runs make_examples over a fixed region and verifies its outputs.

        The runner is invoked once per task id, then the candidates, the
        examples and (in calling mode) the gVCF records are read back and
        checked, including a comparison of the examples with the golden file
        for the given mode.

        Args:
          mode: Either 'calling' or 'training'; assigned to FLAGS.mode.
          num_shards: Shard count passed to _sharded for every output and
            golden path. The runner is executed max(num_shards, 1) times.
        """
        self.assertIn(mode, {'calling', 'training'})
        region = ranges.parse_literal('chr20:10,000,000-10,010,000')
        FLAGS.ref = test_utils.CHR20_FASTA
        FLAGS.reads = test_utils.CHR20_BAM
        FLAGS.candidates = test_utils.test_tmpfile(
            _sharded('vsc.tfrecord', num_shards))
        FLAGS.examples = test_utils.test_tmpfile(
            _sharded('examples.tfrecord', num_shards))
        FLAGS.regions = [ranges.to_literal(region)]
        FLAGS.partition_size = 1000
        FLAGS.mode = mode

        # Calling mode additionally writes a gVCF; training mode instead needs
        # truth variants and confident regions.
        if mode == 'calling':
            FLAGS.gvcf = test_utils.test_tmpfile(
                _sharded('gvcf.tfrecord', num_shards))
        else:
            FLAGS.truth_variants = test_utils.TRUTH_VARIANTS_VCF
            FLAGS.confident_regions = test_utils.CONFIDENT_REGIONS_BED

        # Run once per task. max(..., 1) guarantees at least one iteration, so
        # `options` is always bound for the checks below (it holds the options
        # of the last task).
        for task_id in range(max(num_shards, 1)):
            FLAGS.task = task_id
            options = make_examples.default_options(add_flags=True)
            make_examples.make_examples_runner(options)

        # Test that our candidates are reasonable, calling specific helper functions
        # to check lots of properties of the output.
        candidates = _sort_candidates(
            io_utils.read_tfrecords(FLAGS.candidates,
                                    proto=deepvariant_pb2.DeepVariantCall))
        self.verify_deepvariant_calls(candidates, options)
        self.verify_variants([call.variant for call in candidates],
                             region,
                             options,
                             is_gvcf=False)

        # Verify that the variants in the examples are all good.
        examples = self.verify_examples(FLAGS.examples,
                                        region,
                                        options,
                                        verify_labels=mode == 'training')
        example_variants = [tf_utils.example_variant(ex) for ex in examples]
        self.verify_variants(example_variants, region, options, is_gvcf=False)

        # Verify the integrity of the examples and then check that they match our
        # golden labeled examples. Note we expect the order for both training and
        # calling modes to produce deterministic order because we fix the random
        # seed.
        if mode == 'calling':
            golden_file = _sharded(test_utils.GOLDEN_CALLING_EXAMPLES,
                                   num_shards)
        else:
            golden_file = _sharded(test_utils.GOLDEN_TRAINING_EXAMPLES,
                                   num_shards)
        self.assertDeepVariantExamplesEqual(
            examples, list(io_utils.read_tfrecords(golden_file)))

        if mode == 'calling':
            # Compare the called variants with the truth VCF restricted to the
            # same region.
            nist_reader = genomics_io.make_vcf_reader(
                test_utils.TRUTH_VARIANTS_VCF)
            nist_variants = list(nist_reader.query(region))
            self.verify_nist_concordance(example_variants, nist_variants)

            # Check the quality of our generated gvcf file.
            gvcfs = _sort_variants(
                io_utils.read_tfrecords(FLAGS.gvcf,
                                        proto=variants_pb2.Variant))
            self.verify_variants(gvcfs, region, options, is_gvcf=True)
            self.verify_contiguity(gvcfs, region)