Exemple #1
0
    def test_prepare_inputs(self, filename_to_write, file_string_input):
        """Round-trips self.examples through prepare_inputs and compares variants.

        Args:
          filename_to_write: file name, relative to the test tmp dir, that
            self.examples is written to.
          file_string_input: input spec handed to prepare_inputs; may be a
            comma-separated list of names relative to the test tmp dir.
        """
        written_path = test_utils.test_tmpfile(filename_to_write)
        io_utils.write_tfrecords(self.examples, written_path)
        # Every comma-separated entry needs the tmp-dir prefix, so split,
        # prefix each piece, and glue the pieces back together.
        prefixed_parts = []
        for part in file_string_input.split(','):
            prefixed_parts.append(test_utils.test_tmpfile(part))
        input_spec = ','.join(prefixed_parts)

        with self.test_session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            dataset = call_variants.prepare_inputs(input_spec)
            _, encoded_variants, _ = data_providers.get_infer_batches(
                dataset, model=self.model, batch_size=1)

            # Drain the pipeline one batch at a time; the loop ends when
            # sess.run raises OutOfRangeError.
            collected = []
            while True:
                try:
                    batch = sess.run(encoded_variants)
                except tf.errors.OutOfRangeError:
                    break
                collected.extend(batch)

            self.assertItemsEqual(
                self.variants, variant_utils.decode_variants(collected))
Exemple #2
0
 def testGetNoneShapeFromEmptyExamplesPath(self, file_name_to_write,
                                           tfrecord_path_to_match):
     """Checks that a record-less file yields a None shape.

     Args:
       file_name_to_write: name, relative to the test tmp dir, of the empty
         TFRecord file to create.
       tfrecord_path_to_match: name/spec, relative to the test tmp dir,
         passed to get_shape_from_examples_path.
     """
     empty_records_path = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([], empty_records_path)
     lookup_path = test_utils.test_tmpfile(tfrecord_path_to_match)
     observed_shape = tf_utils.get_shape_from_examples_path(lookup_path)
     self.assertIsNone(observed_shape)
    def test_prepare_inputs(self, filename, expand_to_file_pattern):
        """Round-trips self.examples through prepare_inputs and compares variants.

        Args:
          filename: name, relative to the test tmp dir, to write examples to.
          expand_to_file_pattern: when truthy, the written path is converted
            with NormalizeToShardedFilePattern before being read back.
        """
        written_path = test_utils.test_tmpfile(filename)
        io_utils.write_tfrecords(self.examples, written_path)
        input_spec = written_path
        if expand_to_file_pattern:
            # Transform foo@3 to foo-?????-of-00003.
            input_spec = io_utils.NormalizeToShardedFilePattern(written_path)

        with self.test_session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            dataset = call_variants.prepare_inputs(input_spec)
            _, encoded_variants, _ = data_providers.get_infer_batches(
                dataset, model=self.model, batch_size=1)

            # Keep pulling batches until sess.run raises OutOfRangeError.
            collected = []
            while True:
                try:
                    batch = sess.run(encoded_variants)
                except tf.errors.OutOfRangeError:
                    break
                collected.extend(batch)

            self.assertItemsEqual(
                self.variants, variant_utils.decode_variants(collected))
 def test_call_variants_with_empty_input(self):
     """Smoke test: prepare_inputs must not raise on a record-less file."""
     empty_path = test_utils.test_tmpfile('empty.tfrecord')
     io_utils.write_tfrecords([], empty_path)
     # There is no assertion on purpose: returning without an exception is
     # the pass condition.
     random_guess = modeling.get_model('random_guess')
     call_variants.prepare_inputs(empty_path, random_guess, batch_size=1)
Exemple #5
0
 def testGetShapeFromExamplesPath(self, file_name_to_write,
                                  tfrecord_path_to_match):
   """Checks that the image/shape stored in an example is read back.

   Args:
     file_name_to_write: name, relative to the test tmp dir, of the TFRecord
       file to create.
     tfrecord_path_to_match: name/spec, relative to the test tmp dir, passed
       to get_shape_from_examples_path.
   """
   example = example_pb2.Example()
   valid_shape = [1, 2, 3]
   example.features.feature['image/shape'].int64_list.value.extend(valid_shape)
   output_file = test_utils.test_tmpfile(file_name_to_write)
   io_utils.write_tfrecords([example], output_file)
   # Assert on the returned shape; previously the result was discarded, so the
   # test could only fail if the call raised. list() normalizes whatever
   # sequence type is returned before comparing against the plain list.
   self.assertEqual(
       valid_shape,
       list(
           tf_utils.get_shape_from_examples_path(
               test_utils.test_tmpfile(tfrecord_path_to_match))))
 def test_call_end2end_with_empty_shards(self):
   """Runs the random-guess record-count check over more shards than records."""
   # Read at most 10 golden examples.
   golden_reader = io_utils.read_tfrecords(
       testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_reader)
   # 15 shards for at most 10 records guarantees that several shards are empty.
   sharded_spec = test_utils.test_tmpfile('sharded@{}'.format(15))
   io_utils.write_tfrecords(examples, sharded_spec)
   expected_count = len(examples)
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                      expected_count)
Exemple #7
0
 def test_call_end2end_with_empty_shards(self):
     """Runs the random-guess record-count check over more shards than records."""
     # Read at most 10 golden examples.
     golden_reader = io_utils.read_tfrecords(
         testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
     examples = list(golden_reader)
     # 15 shards for at most 10 records guarantees several empty shards.
     sharded_spec = test_utils.test_tmpfile('sharded@{}'.format(15))
     io_utils.write_tfrecords(examples, sharded_spec)
     expected_count = len(examples)
     self.assertCallVariantsEmitsNRecordsForRandomGuess(
         sharded_spec, expected_count)
def make_golden_dataset(compressed_inputs=False):
  """Returns a path to the golden postprocess input.

  Args:
    compressed_inputs: when falsy, the golden path itself is returned. When
      truthy, the golden CallVariantsOutput records are copied into a tmp
      file with a .tfrecord.gz name and that path is returned instead.

  Returns:
    The path to read the golden records from.
  """
  if not compressed_inputs:
    return testdata.GOLDEN_POSTPROCESS_INPUT

  copy_path = test_utils.test_tmpfile(
      'golden.postprocess_single_site_input.tfrecord.gz')
  golden_records = io_utils.read_tfrecords(
      testdata.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)
  io_utils.write_tfrecords(golden_records, copy_path)
  return copy_path
def make_golden_dataset(compressed_inputs=False):
  """Builds a DeepVariantDataSet over the golden training examples.

  Args:
    compressed_inputs: when truthy, the golden examples are first copied into
      a tmp file with a .tfrecord.gz name and the dataset points at the copy;
      otherwise it points at the golden file directly.

  Returns:
    A data_providers.DeepVariantDataSet named 'labeled_golden'.
  """
  source_path = testdata.GOLDEN_TRAINING_EXAMPLES
  if compressed_inputs:
    source_path = test_utils.test_tmpfile('make_golden_dataset.tfrecord.gz')
    golden_records = io_utils.read_tfrecords(testdata.GOLDEN_TRAINING_EXAMPLES)
    io_utils.write_tfrecords(golden_records, source_path)
  return data_providers.DeepVariantDataSet(
      name='labeled_golden',
      source=source_path,
      num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES)
Exemple #10
0
 def test_call_end2end_empty_first_shard(self):
   """Runs the random-guess record-count check when shard 0 of 2 is empty."""
   # Read at most 10 golden examples.
   golden_reader = io_utils.read_tfrecords(
       testdata.GOLDEN_CALLING_EXAMPLES, max_records=10)
   examples = list(golden_reader)
   # Shard 0 gets no records; shard 1 gets all of them.
   first_shard = test_utils.test_tmpfile('empty_1st_shard-00000-of-00002')
   io_utils.write_tfrecords([], first_shard)
   second_shard = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
   io_utils.write_tfrecords(examples, second_shard)
   sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                      len(examples))
 def test_call_end2end_empty_first_shard(self):
   """Runs the random-guess record-count check when shard 0 of 2 is empty."""
   # Read at most 10 golden examples.
   examples = list(
       io_utils.read_tfrecords(
           testdata.GOLDEN_CALLING_EXAMPLES, max_records=10))
   # Write the two shard files by hand: the first empty, the second full.
   leading_empty_shard = test_utils.test_tmpfile(
       'empty_1st_shard-00000-of-00002')
   io_utils.write_tfrecords([], leading_empty_shard)
   populated_shard = test_utils.test_tmpfile('empty_1st_shard-00001-of-00002')
   io_utils.write_tfrecords(examples, populated_shard)
   expected_count = len(examples)
   self.assertCallVariantsEmitsNRecordsForRandomGuess(
       test_utils.test_tmpfile('empty_1st_shard@2'), expected_count)
Exemple #12
0
def make_golden_dataset(compressed_inputs=False):
  """Returns a path to the golden postprocess input.

  Args:
    compressed_inputs: when falsy, the golden path itself is returned. When
      truthy, the golden CallVariantsOutput records are copied into a tmp
      file with a .tfrecord.gz name and that path is returned instead.

  Returns:
    The path to read the golden records from.
  """
  if not compressed_inputs:
    return testdata.GOLDEN_POSTPROCESS_INPUT

  copy_path = test_utils.test_tmpfile(
      'golden.postprocess_single_site_input.tfrecord.gz')
  golden_records = io_utils.read_tfrecords(
      testdata.GOLDEN_POSTPROCESS_INPUT,
      proto=deepvariant_pb2.CallVariantsOutput)
  io_utils.write_tfrecords(golden_records, copy_path)
  return copy_path
 def test_call_variants_with_no_shape(self, model):
   """Expects prepare_inputs to raise ValueError when image/shape is missing.

   Args:
     model: the model passed through to prepare_inputs.
   """
   # Start from the first record of the golden calling examples.
   golden_reader = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
   example = next(golden_reader)
   # Strip the shape feature so the example becomes invalid.
   del example.features.feature['image/shape']
   noshape_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([example], noshape_path)
   expected_message = ('Invalid image/shape: we expect to find an image/shape '
                       'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_message):
     call_variants.prepare_inputs(noshape_path, model, batch_size=1)
 def test_get_shape_from_examples_path(self, file_name_to_write,
                                       tfrecord_path_to_match):
   """Checks that DeepVariantDataSet.tensor_shape reflects image/shape.

   Args:
     file_name_to_write: name, relative to the test tmp dir, of the TFRecord
       file to create.
     tfrecord_path_to_match: name/spec, relative to the test tmp dir, used as
       the dataset source.
   """
   valid_shape = [1, 2, 3]
   example = example_pb2.Example()
   shape_feature = example.features.feature['image/shape']
   shape_feature.int64_list.value.extend(valid_shape)
   written_path = test_utils.test_tmpfile(file_name_to_write)
   io_utils.write_tfrecords([example], written_path)
   dataset_source = test_utils.test_tmpfile(tfrecord_path_to_match)
   dataset = data_providers.DeepVariantDataSet(
       name='test_shape', source=dataset_source, num_examples=1)
   self.assertEqual(valid_shape, dataset.tensor_shape)
Exemple #15
0
 def test_call_variants_with_no_shape(self, model):
   """Expects a ValueError when the input example lacks image/shape.

   Args:
     model: the model passed through to get_infer_batches.
   """
   # Start from the first record of the golden calling examples.
   golden_reader = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
   example = next(golden_reader)
   # Strip the shape feature so the example becomes invalid.
   del example.features.feature['image/shape']
   noshape_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([example], noshape_path)
   expected_message = ('Invalid image/shape: we expect to find an image/shape '
                       'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_message):
     dataset = call_variants.prepare_inputs(noshape_path)
     batches = data_providers.get_infer_batches(
         dataset, model=model, batch_size=1)
     _ = list(batches)
Exemple #16
0
  def test_reading_empty_input_should_raise_error(self):
    """Expects postprocess_variants.main to raise when every shard is empty."""
    shard_names = ('no_records.tfrecord-00000-of-00002',
                   'no_records.tfrecord-00001-of-00002')
    shard_paths = [test_utils.test_tmpfile(name) for name in shard_names]
    # Both shards exist on disk but hold no records.
    for shard_path in shard_paths:
      io_utils.write_tfrecords([], shard_path)

    FLAGS.infile = test_utils.test_tmpfile('no_records.tfrecord@2')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('no_records.vcf')

    argv = ['postprocess_variants.py']
    with self.assertRaisesRegexp(ValueError, 'Cannot find any records in'):
      postprocess_variants.main(argv)
 def test_get_shape_from_examples_path(self, file_name_to_write,
                                       tfrecord_path_to_match):
     """Checks that DeepVariantDataSet.tensor_shape reflects image/shape.

     Args:
       file_name_to_write: name, relative to the test tmp dir, of the
         TFRecord file to create.
       tfrecord_path_to_match: name/spec, relative to the test tmp dir, used
         as the dataset source.
     """
     valid_shape = [1, 2, 3]
     example = example_pb2.Example()
     shape_feature = example.features.feature['image/shape']
     shape_feature.int64_list.value.extend(valid_shape)
     written_path = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([example], written_path)
     dataset_source = test_utils.test_tmpfile(tfrecord_path_to_match)
     dataset = data_providers.DeepVariantDataSet(
         name='test_shape', source=dataset_source, num_examples=1)
     self.assertEqual(valid_shape, dataset.tensor_shape)
def make_golden_dataset(compressed_inputs=False):
    """Builds a DeepVariantDataSet over the golden training examples.

    Args:
      compressed_inputs: when truthy, the golden examples are first copied
        into a tmp file with a .tfrecord.gz name and the dataset points at
        the copy; otherwise it points at the golden file directly.

    Returns:
      A data_providers.DeepVariantDataSet named 'labeled_golden'.
    """
    source_path = testdata.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        source_path = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_records = io_utils.read_tfrecords(
            testdata.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(golden_records, source_path)
    return data_providers.DeepVariantDataSet(
        name='labeled_golden',
        source=source_path,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES)
Exemple #19
0
  def test_reading_sharded_input_with_empty_shards_does_not_crash(self):
    """Smoke test: main() must cope with one empty and one populated shard."""
    golden_calls = io_utils.read_tfrecords(
        testdata.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    empty_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00000-of-00002')
    populated_shard = test_utils.test_tmpfile(
        'reading_empty_shard.tfrecord-00001-of-00002')
    # Shard 0 holds nothing; shard 1 holds every golden record.
    io_utils.write_tfrecords([], empty_shard)
    io_utils.write_tfrecords(golden_calls, populated_shard)

    FLAGS.infile = test_utils.test_tmpfile('reading_empty_shard.tfrecord@2')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.outfile = test_utils.test_tmpfile('calls_reading_empty_shard.vcf')

    # No assertion: finishing without an exception is the pass condition.
    argv = ['postprocess_variants.py']
    postprocess_variants.main(argv)
    def test_reading_sharded_dataset(self, compressed_inputs):
        """Re-shards the golden dataset and checks it reads back unchanged.

        Args:
          compressed_inputs: forwarded to make_golden_dataset.
        """
        golden = make_golden_dataset(compressed_inputs)
        shard_count = 3
        sharded_spec = test_utils.test_tmpfile(
            'sharded@{}'.format(shard_count))
        golden_records = io_utils.read_tfrecords(golden.source)
        io_utils.write_tfrecords(golden_records, sharded_spec)

        # Describe the sharded copy in a dataset config file.
        config_path = _test_dataset_config(
            'test_sharded.pbtxt',
            name='sharded_test',
            tfrecord_path=sharded_spec,
            num_examples=golden.num_examples)

        sharded_dataset = data_providers.get_dataset(config_path)
        self.assertDataSetExamplesMatchExpected(
            sharded_dataset.get_slim_dataset(), golden)
  def test_reading_sharded_dataset(self, compressed_inputs):
    """Re-shards the golden dataset and checks it reads back unchanged.

    Args:
      compressed_inputs: forwarded to make_golden_dataset.
    """
    golden = make_golden_dataset(compressed_inputs)
    shard_count = 3
    sharded_spec = test_utils.test_tmpfile('sharded@{}'.format(shard_count))
    golden_records = io_utils.read_tfrecords(golden.source)
    io_utils.write_tfrecords(golden_records, sharded_spec)

    # Describe the sharded copy in a dataset config file.
    config_path = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=sharded_spec,
        num_examples=golden.num_examples)

    sharded_dataset = data_providers.get_dataset(config_path)
    self.assertDataSetExamplesMatchExpected(
        sharded_dataset.get_slim_dataset(), golden)
Exemple #22
0
  def test_call_variants_with_empty_input(self):
    """Smoke test: an empty input file must not crash the inference pipeline."""
    empty_path = test_utils.test_tmpfile('empty.tfrecord')
    io_utils.write_tfrecords([], empty_path)
    # prepare_inputs itself must accept the empty file.
    dataset = call_variants.prepare_inputs(empty_path)
    random_guess = modeling.get_model('random_guess')

    batch_tensors = list(
        data_providers.get_infer_batches(
            dataset, model=random_guess, batch_size=1))
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())
      try:
        sess.run(batch_tensors)
      except tf.errors.OutOfRangeError:
        # The API specifies that OutOfRangeError is thrown in this case, so
        # it is tolerated here; any other exception fails the test.
        pass
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Expects call_variants to raise ValueError for a bad image/format.

    Args:
      model: the model passed through to call_variants.
      bad_format: value written into image/format (anything but 'raw').
    """
    # Take the first golden record and corrupt its image/format feature.
    golden_reader = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
    example = next(golden_reader)
    format_feature = example.features.feature['image/format']
    format_feature.bytes_list.value[0] = bad_format
    bad_examples_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([example], bad_examples_path)
    output_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=bad_examples_path,
          checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
          model=model,
          output_file=output_path,
          batch_size=1,
          max_batches=1)
def make_golden_dataset(compressed_inputs=False,
                        mode=tf.estimator.ModeKeys.EVAL,
                        use_tpu=False):
    """Builds an input_fn over the golden training examples.

    Args:
      compressed_inputs: when truthy, the golden examples are first copied
        into a tmp file with a .tfrecord.gz name and that copy is used as
        the input; otherwise the golden file is used directly.
      mode: forwarded to get_input_fn_from_filespec.
      use_tpu: forwarded to get_input_fn_from_filespec.

    Returns:
      Whatever data_providers.get_input_fn_from_filespec returns for the
      chosen path, named 'labeled_golden'.
    """
    input_spec = testdata.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        input_spec = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        golden_records = io_utils.read_tfrecords(
            testdata.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(golden_records, input_spec)
    return data_providers.get_input_fn_from_filespec(
        input_file_spec=input_spec,
        num_examples=testdata.N_GOLDEN_TRAINING_EXAMPLES,
        name='labeled_golden',
        mode=mode,
        tensor_shape=None,
        use_tpu=use_tpu)
Exemple #25
0
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Expects call_variants to raise ValueError for a bad image/format.

    Args:
      model: the model passed through to call_variants.
      bad_format: value written into image/format (anything but 'raw').
    """
    # Take the first golden record and corrupt its image/format feature.
    golden_reader = io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES)
    example = next(golden_reader)
    format_feature = example.features.feature['image/format']
    format_feature.bytes_list.value[0] = bad_format
    bad_examples_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([example], bad_examples_path)
    output_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=bad_examples_path,
          checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
          model=model,
          output_file=output_path,
          batch_size=1,
          max_batches=1,
          use_tpu=FLAGS.use_tpu)
Exemple #26
0
    def _call_end2end_helper(self, examples_path, model, shard_inputs):
        """Runs call_variants over examples_path and reads back its output.

        Args:
          examples_path: path of the examples to run inference on.
          model: model to call with; a model named 'random_guess' is run
            over every batch, any other model over a single batch.
          shard_inputs: when truthy, the examples are first rewritten as a
            3-way sharded tmp file and inference reads from that instead.

        Returns:
          A tuple (call_variants_outputs, examples, batch_size, max_batches).
        """
        examples = list(io_utils.read_tfrecords(examples_path))

        source_path = examples_path
        if shard_inputs:
            # Create a sharded version of our golden examples.
            source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
            io_utils.write_tfrecords(examples, source_path)

        # If we point the test at a headless server, it will often be 2x2,
        # which has 8 replicas.  Otherwise a smaller batch size is fine.
        batch_size = 8 if FLAGS.use_tpu else 4

        # For the random guess model we can run everything; for all other
        # models we only run a single batch for inference.
        max_batches = None if model.name == 'random_guess' else 1

        output_path = test_utils.test_tmpfile('call_variants.tfrecord')
        call_variants.call_variants(
            examples_filename=source_path,
            checkpoint_path=_LEAVE_MODEL_UNINITIALIZED,
            model=model,
            output_file=output_path,
            batch_size=batch_size,
            max_batches=max_batches,
            master='',
            use_tpu=FLAGS.use_tpu,
        )

        output_reader = io_utils.read_tfrecords(
            output_path, deepvariant_pb2.CallVariantsOutput)
        call_variants_outputs = list(output_reader)

        return call_variants_outputs, examples, batch_size, max_batches
  def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Round-trips self.examples through prepare_inputs and compares variants.

    Args:
      filename: name, relative to the test tmp dir, to write examples to.
      expand_to_file_pattern: when truthy, the written path is converted with
        NormalizeToShardedFilePattern before being read back.
    """
    written_path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, written_path)
    input_spec = written_path
    if expand_to_file_pattern:
      # Transform foo@3 to foo-?????-of-00003.
      input_spec = io_utils.NormalizeToShardedFilePattern(written_path)

    with self.test_session() as sess:
      # The input ops are created before the initializers are run.
      _, encoded_variants, _ = call_variants.prepare_inputs(
          input_spec, self.model, batch_size=1)
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())

      # Keep pulling batches until sess.run raises OutOfRangeError.
      collected = []
      while True:
        try:
          batch = sess.run(encoded_variants)
        except tf.errors.OutOfRangeError:
          break
        collected.extend(batch)

      self.assertItemsEqual(self.variants,
                            variant_utils.decode_variants(collected))
    def test_reading_sharded_dataset(self, compressed_inputs, use_tpu):
        """Re-shards the golden dataset and checks it reads back unchanged.

        Args:
          compressed_inputs: forwarded to make_golden_dataset.
          use_tpu: forwarded to make_golden_dataset.
        """
        golden = make_golden_dataset(compressed_inputs, use_tpu=use_tpu)
        shard_count = 3
        sharded_spec = test_utils.test_tmpfile(
            'sharded@{}'.format(shard_count))
        golden_records = io_utils.read_tfrecords(golden.input_file_spec)
        io_utils.write_tfrecords(golden_records, sharded_spec)

        # Describe the sharded copy in a dataset config file.
        config_path = _test_dataset_config(
            'test_sharded.pbtxt',
            name='sharded_test',
            tfrecord_path=sharded_spec,
            num_examples=golden.num_examples)

        sharded_input_fn = data_providers.get_input_fn_from_dataset(
            config_path, mode=tf.estimator.ModeKeys.EVAL)
        self.assertTfDataSetExamplesMatchExpected(
            sharded_input_fn,
            golden,
            # workaround_list_files is needed because wildcards, and so
            # sharded files, are nondeterministically ordered (for now).
            workaround_list_files=True,
        )
 def test_call_end2end_zero_record_file(self):
     """Runs the random-guess record-count check on a file with no records."""
     empty_path = test_utils.test_tmpfile('zero_record_file')
     io_utils.write_tfrecords([], empty_path)
     # The path is resolved a second time from the same name for the check.
     input_path = test_utils.test_tmpfile('zero_record_file')
     expected_count = 0
     self.assertCallVariantsEmitsNRecordsForRandomGuess(
         input_path, expected_count)
Exemple #30
0
 def write_test_protos(self, filename):
   """Writes ten ContigInfo protos to a tmp file.

   Args:
     filename: name, relative to the test tmp dir, of the file to write.

   Returns:
     A tuple (protos, path): the list of protos written, whose names are the
     strings '0' through '9', and the full path they were written to.
   """
   protos = []
   for index in range(10):
     protos.append(reference_pb2.ContigInfo(name=str(index)))
   path = test_utils.test_tmpfile(filename)
   io.write_tfrecords(protos, path)
   return protos, path
Exemple #31
0
 def test_call_end2end_zero_record_file_for_inception_v3(self):
     """Runs the InceptionV3 record-count check on a file with no records."""
     empty_path = test_utils.test_tmpfile('zero_record_file')
     io_utils.write_tfrecords([], empty_path)
     # The path is resolved a second time from the same name for the check.
     input_path = test_utils.test_tmpfile('zero_record_file')
     expected_count = 0
     self.assertCallVariantsEmitsNRecordsForInceptionV3(
         input_path, expected_count)
 def test_call_end2end_zero_record_file(self):
   """Runs the random-guess record-count check on a file with no records."""
   empty_path = test_utils.test_tmpfile('zero_record_file')
   io_utils.write_tfrecords([], empty_path)
   # The path is resolved a second time from the same name for the check.
   input_path = test_utils.test_tmpfile('zero_record_file')
   expected_count = 0
   self.assertCallVariantsEmitsNRecordsForRandomGuess(input_path,
                                                      expected_count)
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """Runs call_variants end to end and validates the emitted records.

    Args:
      model: model to call with; a model named 'random_guess' is run over
        every example, any other model over a single batch.
      shard_inputs: when truthy, inference reads a 3-way sharded copy of the
        golden examples instead of the golden file itself.
      include_debug_info: value assigned to FLAGS.include_debug_info; also
        selects which debug_info checks are applied to the output.
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = testdata.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs), batch_size * max_batches
        if max_batches else len(examples))

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants (max_batches
    # is None), since we processed all of the examples with that model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variant_utils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variant_utils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp, variant_utils.is_snp(
            cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(
                  example) == call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though currently
      # as implemented we find our example using this information so it cannot
      # fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1. all() is
      # required: a bare generator expression is always truthy, which would
      # make this assertion pass regardless of the values.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))
    def test_call_end2end(self, model, shard_inputs, include_debug_info):
        """Runs call_variants end to end and validates the emitted records.

        Args:
          model: model to call with; a model named 'random_guess' is run
            over every example, any other model over a single batch.
          shard_inputs: when truthy, inference reads a 3-way sharded copy of
            the golden examples instead of the golden file itself.
          include_debug_info: value assigned to FLAGS.include_debug_info;
            also selects which debug_info checks are applied to the output.
        """
        FLAGS.include_debug_info = include_debug_info
        examples = list(
            io_utils.read_tfrecords(testdata.GOLDEN_CALLING_EXAMPLES))

        if shard_inputs:
            # Create a sharded version of our golden examples.
            source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
            io_utils.write_tfrecords(examples, source_path)
        else:
            source_path = testdata.GOLDEN_CALLING_EXAMPLES

        batch_size = 4
        if model.name == 'random_guess':
            # For the random guess model we can run everything.
            max_batches = None
        else:
            # For all other models we only run a single batch for inference.
            max_batches = 1

        outfile = test_utils.test_tmpfile('call_variants.tfrecord')
        call_variants.call_variants(
            examples_filename=source_path,
            checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
            model=model,
            output_file=outfile,
            batch_size=batch_size,
            max_batches=max_batches)

        call_variants_outputs = list(
            io_utils.read_tfrecords(outfile,
                                    deepvariant_pb2.CallVariantsOutput))

        # Check that we have the right number of output protos.
        self.assertEqual(
            len(call_variants_outputs),
            batch_size * max_batches if max_batches else len(examples))

        # Check that our CallVariantsOutput (CVO) have the following critical
        # properties:
        # - we have one CVO for each example we processed.
        # - the variant in the CVO is exactly what was in the example.
        # - the alt_allele_indices of the CVO match those of its corresponding
        #   example.
        # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
        # We can only do this test when processing all of the variants (max_batches
        # is None), since we processed all of the examples with that model.
        if max_batches is None:
            self.assertItemsEqual(
                [cvo.variant for cvo in call_variants_outputs],
                [tf_utils.example_variant(ex) for ex in examples])

        # Check the CVO debug_info: not filled if include_debug_info is False;
        # else, filled by logic based on CVO.
        if not include_debug_info:
            for cvo in call_variants_outputs:
                self.assertEqual(
                    cvo.debug_info,
                    deepvariant_pb2.CallVariantsOutput.DebugInfo())
        else:
            for cvo in call_variants_outputs:
                self.assertEqual(cvo.debug_info.has_insertion,
                                 variant_utils.has_insertion(cvo.variant))
                self.assertEqual(cvo.debug_info.has_deletion,
                                 variant_utils.has_deletion(cvo.variant))
                self.assertEqual(cvo.debug_info.is_snp,
                                 variant_utils.is_snp(cvo.variant))
                self.assertEqual(cvo.debug_info.predicted_label,
                                 np.argmax(cvo.genotype_probabilities))

        def example_matches_call_variants_output(example,
                                                 call_variants_output):
            return (tf_utils.example_variant(example)
                    == call_variants_output.variant
                    and tf_utils.example_alt_alleles_indices(example)
                    == call_variants_output.alt_allele_indices.indices)

        for call_variants_output in call_variants_outputs:
            # Find all matching examples.
            matches = [
                ex for ex in examples if example_matches_call_variants_output(
                    ex, call_variants_output)
            ]
            # We should have exactly one match.
            self.assertEqual(len(matches), 1)
            example = matches[0]
            # Check that we've faithfully copied in the alt alleles (though currently
            # as implemented we find our example using this information so it cannot
            # fail). Included here in case that changes in the future.
            self.assertEqual(
                list(tf_utils.example_alt_alleles_indices(example)),
                list(call_variants_output.alt_allele_indices.indices))
            # We should have exactly three genotype probabilities (assuming our
            # ploidy == 2).
            self.assertEqual(len(call_variants_output.genotype_probabilities),
                             3)
            # These are probabilities so they should be between 0 and 1.
            # all() is required: a bare generator expression is always
            # truthy, which would make this assertion pass regardless of the
            # values.
            self.assertTrue(
                all(0 <= gp <= 1
                    for gp in call_variants_output.genotype_probabilities))
Exemple #35
0
def call_variants(examples_filename,
                  checkpoint_path,
                  model,
                  output_file,
                  execution_hardware='auto',
                  batch_size=16,
                  max_batches=None):
  """Main driver of call_variants.

  Reads examples from examples_filename, evaluates them with model, and writes
  the resulting calls to output_file. If no record can be read from
  examples_filename, an empty output_file is written and the function returns
  without building a graph.

  Args:
    examples_filename: path/spec of the TF examples to call.
    checkpoint_path: checkpoint passed to model.initialize_from_checkpoint.
    model: object providing create() and initialize_from_checkpoint().
    output_file: destination the calls are written to.
    execution_hardware: must be a member of _ALLOW_EXECUTION_HARDWARE. 'cpu'
      sets the GPU and TPU device counts to 0; 'accelerator' requires the
      session to list at least one non-CPU device.
    batch_size: forwarded to prepare_inputs.
    max_batches: stop after this many batches; None runs until the input is
      exhausted.

  Raises:
    ValueError: if the examples' image/format is not 'raw', or if
      execution_hardware is not an allowed value.
    ExecutionHardwareError: if execution_hardware is 'accelerator' but the
      session only lists CPU devices.
  """
  # Read a single TFExample to make sure we're not loading an older version.
  example_format = tf_utils.get_format_from_examples_path(examples_filename)
  if example_format is None:
    logging.warning('Unable to read any records from %s. Output will contain '
                    'zero records.', examples_filename)
    io_utils.write_tfrecords([], output_file)
    # The empty output is already complete; stop here instead of building a
    # graph and reopening output_file for an input that has no records.
    return
  elif example_format != 'raw':
    raise ValueError('The TF examples in {} has image/format \'{}\' '
                     '(expected \'raw\') which means you might need to rerun '
                     'make_examples to generate the examples again.'.format(
                         examples_filename, example_format))

  if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
    raise ValueError(
        'Unexpected execution_hardware={} value. Allowed values are {}'.format(
            execution_hardware, ','.join(_ALLOW_EXECUTION_HARDWARE)))

  with tf.Graph().as_default():
    images, encoded_variants, encoded_alt_allele_indices = prepare_inputs(
        examples_filename, model, batch_size, FLAGS.num_readers)

    # Create our model and extract the predictions from the model endpoints.
    predictions = model.create(images, 3, is_training=False)['Predictions']

    # The op for initializing the variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    device_count = {'GPU': 0, 'TPU': 0} if execution_hardware == 'cpu' else {}
    config = tf.ConfigProto(device_count=device_count)
    with tf.Session(config=config) as sess:
      sess.run(init_op)

      # Initialize the model from the provided checkpoint using our session.
      logging.info('Initializing model from %s', checkpoint_path)
      model.initialize_from_checkpoint(checkpoint_path, 3, False)(sess)

      if execution_hardware == 'accelerator':
        if not any(dev.device_type != 'CPU' for dev in sess.list_devices()):
          raise ExecutionHardwareError(
              'execution_hardware is set to accelerator, but no accelerator '
              'was found')

      logging.info('Writing calls to %s', output_file)
      writer, _ = io_utils.make_proto_writer(output_file)
      with writer:
        start_time = time.time()
        try:
          n_batches = 0
          n_examples = 0
          while max_batches is None or n_batches < max_batches:
            n_called = call_batch(sess, writer, encoded_variants,
                                  encoded_alt_allele_indices, predictions)

            duration = time.time() - start_time
            n_batches += 1
            n_examples += n_called
            # Guard the rate so a batch that reports zero called examples
            # cannot abort the run with ZeroDivisionError in a log statement.
            sec_per_100 = (100 * duration) / n_examples if n_examples else 0.0
            logging.info(
                ('Processed %s examples in %s batches [%.2f sec per 100]'),
                n_examples, n_batches, sec_per_100)
        except tf.errors.OutOfRangeError:
          # Raised once the input is exhausted; this is the normal way out of
          # the loop when max_batches is None.
          logging.info('Done evaluating variants')
Exemple #36
0
def call_variants(examples_filename,
                  checkpoint_path,
                  model,
                  output_file,
                  execution_hardware='auto',
                  batch_size=16,
                  max_batches=None,
                  use_tpu=False,
                  master=''):
    """Main driver of call_variants.

    Validates the input examples and the requested execution hardware, builds
    an estimator from `model`, then consumes its predictions one at a time and
    writes each of them to `output_file` via `write_variant_call`.

    If no record can be read from `examples_filename`, an empty record file is
    written to `output_file` and the function returns early.

    Args:
      examples_filename: Path of the TF examples to evaluate; forwarded to
        `prepare_inputs` as `source_path`.
      checkpoint_path: Checkpoint forwarded to `estimator.predict` and to each
        of the model's session predict hooks. When None, no predict hooks are
        installed.
      model: Object providing `make_estimator` and `session_predict_hooks`.
      output_file: Destination passed to `io_utils.make_proto_writer`.
      execution_hardware: Must be one of `_ALLOW_EXECUTION_HARDWARE`. With
        'cpu' the probing session is configured with zero GPU/TPU devices;
        with 'accelerator' at least one non-CPU device must be listed by that
        session.
      batch_size: Batch size forwarded to `model.make_estimator`; also used to
        turn `max_batches` into an example limit and to report batch counts.
      max_batches: Optional cap on the number of batches to process. None means
        no cap. At most `max_batches * batch_size` predictions are written, so
        a value of 0 writes nothing.
      use_tpu: Forwarded to `prepare_inputs`, `model.make_estimator` and
        `write_variant_call`.
      master: Forwarded to `model.make_estimator`.

    Raises:
      ValueError: If the examples' image format is not 'raw', or if
        `execution_hardware` is not an allowed value.
      ExecutionHardwareError: If `execution_hardware` is 'accelerator' but the
        probing session lists only CPU devices.
    """
    if FLAGS.kmp_blocktime:
        os.environ['KMP_BLOCKTIME'] = FLAGS.kmp_blocktime
        logging.info('Set KMP_BLOCKTIME to %s', os.environ['KMP_BLOCKTIME'])

    # Read a single TFExample to make sure we're not loading an older version.
    example_format = tf_utils.get_format_from_examples_path(examples_filename)
    if example_format is None:
        logging.warning(
            'Unable to read any records from %s. Output will contain '
            'zero records.', examples_filename)
        io_utils.write_tfrecords([], output_file)
        return
    elif example_format != 'raw':
        raise ValueError(
            'The TF examples in {} has image/format \'{}\' '
            '(expected \'raw\') which means you might need to rerun '
            'make_examples to generate the examples again.'.format(
                examples_filename, example_format))

    # Check accelerator status.
    if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
        raise ValueError(
            'Unexpected execution_hardware={} value. Allowed values are {}'.
            format(execution_hardware, ','.join(_ALLOW_EXECUTION_HARDWARE)))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    # Requesting zero GPU/TPU devices keeps the session on CPU only.
    device_count = {'GPU': 0, 'TPU': 0} if execution_hardware == 'cpu' else {}
    config = tf.ConfigProto(device_count=device_count)
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if execution_hardware == 'accelerator':
            if not any(dev.device_type != 'CPU'
                       for dev in sess.list_devices()):
                raise ExecutionHardwareError(
                    'execution_hardware is set to accelerator, but no accelerator '
                    'was found')
        # redacted
        # sess.list_devices here doesn't return the correct answer. That can only
        # work later, after the device (on the other VM) has been initialized,
        # which is generally not yet.

    # Prepare input stream and estimator.
    tf_dataset = prepare_inputs(source_path=examples_filename, use_tpu=use_tpu)
    estimator = model.make_estimator(
        batch_size=batch_size,
        master=master,
        use_tpu=use_tpu,
    )

    # Instantiate the prediction "stream", and select the EMA values from
    # the model.
    if checkpoint_path is None:
        # Unit tests use this branch.
        predict_hooks = []
    else:
        predict_hooks = [
            h(checkpoint_path) for h in model.session_predict_hooks()
        ]
    predictions = iter(
        estimator.predict(input_fn=tf_dataset,
                          checkpoint_path=checkpoint_path,
                          hooks=predict_hooks))

    # Predictions arrive one example at a time, so the batch cap is enforced
    # as an example cap. This keeps max_batches=0 from writing anything.
    if max_batches is None:
        max_examples = None
    else:
        max_examples = max_batches * batch_size

    # Consume predictions one at a time and write them to output_file.
    logging.info('Writing calls to %s', output_file)
    writer, _ = io_utils.make_proto_writer(output_file)
    with writer:
        start_time = time.time()
        n_examples, n_batches = 0, 0
        while max_examples is None or n_examples < max_examples:
            try:
                prediction = next(predictions)
            except (StopIteration, tf.errors.OutOfRangeError):
                break
            write_variant_call(writer, prediction, use_tpu)
            n_examples += 1
            # Ceiling division: number of batches needed to cover the examples
            # seen so far (exactly one batch after `batch_size` examples).
            n_batches = (n_examples + batch_size - 1) // batch_size
            duration = time.time() - start_time

            logging.log_every_n(
                logging.INFO,
                ('Processed %s examples in %s batches [%.3f sec per 100]'),
                _LOG_EVERY_N, n_examples, n_batches,
                (100 * duration) / n_examples)

        logging.info('Done evaluating variants')
Exemple #37
0
def call_variants(examples_filename,
                  checkpoint_path,
                  model,
                  output_file,
                  execution_hardware='auto',
                  batch_size=16,
                  max_batches=None):
  """Evaluates the model on the examples and writes the calls to output_file.

  The examples' image format is checked first. When no record can be read, an
  empty record file is written to `output_file` and the function returns.
  Otherwise a fresh graph is built, the model is restored from
  `checkpoint_path`, and `call_batch` is invoked repeatedly until the input is
  exhausted (signalled by `tf.errors.OutOfRangeError`) or `max_batches` batches
  have been processed.

  Args:
    examples_filename: Path of the TF examples, forwarded to `prepare_inputs`.
    checkpoint_path: Checkpoint given to `model.initialize_from_checkpoint`.
    model: Object providing `create` and `initialize_from_checkpoint`.
    output_file: Destination passed to `io_utils.make_proto_writer`.
    execution_hardware: Must be one of `_ALLOW_EXECUTION_HARDWARE`. 'cpu'
      configures the session with zero GPU/TPU devices; 'accelerator' requires
      the session to list at least one non-CPU device.
    batch_size: Batch size forwarded to `prepare_inputs`.
    max_batches: Optional cap on the number of `call_batch` invocations; None
      means no cap.

  Raises:
    ValueError: If the examples' image format is not 'raw', or if
      `execution_hardware` is not an allowed value.
    ExecutionHardwareError: If 'accelerator' was requested but the session
      lists only CPU devices.
  """
  # Peek at the stored image format so examples produced by an older version
  # are rejected up front.
  image_format = tf_utils.get_format_from_examples_path(examples_filename)
  if image_format is None:
    logging.warning('Unable to read any records from %s. Output will contain '
                    'zero records.', examples_filename)
    io_utils.write_tfrecords([], output_file)
    return
  if image_format != 'raw':
    message = ('The TF examples in {} has image/format \'{}\' '
               '(expected \'raw\') which means you might need to rerun '
               'make_examples to generate the examples again.')
    raise ValueError(message.format(examples_filename, image_format))

  if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
    allowed_values = ','.join(_ALLOW_EXECUTION_HARDWARE)
    raise ValueError(
        'Unexpected execution_hardware={} value. Allowed values are {}'.format(
            execution_hardware, allowed_values))

  # A 'cpu' request sets the GPU and TPU device counts to zero.
  if execution_hardware == 'cpu':
    device_limits = {'GPU': 0, 'TPU': 0}
  else:
    device_limits = {}

  with tf.Graph().as_default():
    images, variants_op, alt_allele_indices_op = prepare_inputs(
        examples_filename, model, batch_size, FLAGS.num_readers)

    # Build the model and keep only its 'Predictions' endpoint.
    endpoints = model.create(images, 2, is_training=False)
    predictions = endpoints['Predictions']

    # Single op that initializes both global and local variables.
    initializer = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

    session_config = tf.ConfigProto(device_count=device_limits)
    with tf.Session(config=session_config) as sess:
      sess.run(initializer)

      # Initialize the model from the provided checkpoint using our session.
      logging.info('Initializing model from %s', checkpoint_path)
      restore_fn = model.initialize_from_checkpoint(checkpoint_path, 2, False)
      restore_fn(sess)

      # An explicit accelerator request fails if every device is a CPU.
      if execution_hardware == 'accelerator' and all(
          dev.device_type == 'CPU' for dev in sess.list_devices()):
        raise ExecutionHardwareError(
            'execution_hardware is set to accelerator, but no accelerator '
            'was found')

      logging.info('Writing calls to %s', output_file)
      writer, _ = io_utils.make_proto_writer(output_file)
      with writer:
        start_time = time.time()
        n_batches = 0
        n_examples = 0
        try:
          while True:
            if max_batches is not None and n_batches >= max_batches:
              break
            called = call_batch(sess, writer, variants_op,
                                alt_allele_indices_op, predictions)

            elapsed = time.time() - start_time
            n_batches += 1
            n_examples += called
            logging.info(
                ('Processed %s examples in %s batches [%.2f sec per 100]'),
                n_examples, n_batches, (100 * elapsed) / n_examples)
        except tf.errors.OutOfRangeError:
          # Raised once the input pipeline has no more examples.
          logging.info('Done evaluating variants')
 def test_call_variants_with_empty_input(self):
   """Checks that prepare_inputs can be called on an empty record file."""
   source_path = test_utils.test_tmpfile('empty.tfrecord')
   # Write a file that contains zero records.
   io_utils.write_tfrecords([], source_path)
   # Make sure that prepare_inputs doesn't crash on empty input.
   call_variants.prepare_inputs(
       source_path, modeling.get_model('random_guess'), batch_size=1)