예제 #1
0
 def testGetNoneShapeFromEmptyExamplesPath(self, file_name_to_write,
                                           tfrecord_path_to_match):
     """Asserts that the shape lookup yields None when no records were written.

     Args:
       file_name_to_write: Name passed to test_utils.test_tmpfile for the file
         that receives an empty list of records.
       tfrecord_path_to_match: Name passed to test_utils.test_tmpfile to build
         the path handed to tf_utils.get_shape_from_examples_path.
     """
     empty_records_path = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([], empty_records_path)

     lookup_path = test_utils.test_tmpfile(tfrecord_path_to_match)
     found_shape = tf_utils.get_shape_from_examples_path(lookup_path)
     self.assertIsNone(found_shape)
예제 #2
0
 def test_call_variants_with_empty_input(self):
     """Checks that prepare_inputs can be called on a file with no records."""
     empty_path = test_utils.test_tmpfile('empty.tfrecord')
     io_utils.write_tfrecords([], empty_path)

     # The only expectation here is that prepare_inputs does not crash on
     # empty input; its return value is deliberately ignored.
     random_guess_model = modeling.get_model('random_guess')
     call_variants.prepare_inputs(empty_path, random_guess_model, batch_size=1)
예제 #3
0
 def testGetShapeFromExamplesPath(self, file_name_to_write,
                                  tfrecord_path_to_match):
     """Checks that the shape stored in a written example is read back.

     Args:
       file_name_to_write: Name passed to test_utils.test_tmpfile for the file
         that receives the single example.
       tfrecord_path_to_match: Name passed to test_utils.test_tmpfile to build
         the path handed to tf_utils.get_shape_from_examples_path.
     """
     example = example_pb2.Example()
     valid_shape = [1, 2, 3]
     example.features.feature['image/shape'].int64_list.value.extend(
         valid_shape)
     output_file = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([example], output_file)
     shape = tf_utils.get_shape_from_examples_path(
         test_utils.test_tmpfile(tfrecord_path_to_match))
     # Previously the returned shape was discarded, so the test could only
     # fail if the call itself raised. list() keeps the comparison independent
     # of the concrete sequence type that is returned.
     self.assertEqual(valid_shape, list(shape))
 def test_call_end2end_with_empty_shards(self):
   """Runs the random-guess end-to-end check on input with empty shards."""
   max_examples = 10
   # More shards than examples, which means several shards will be empty.
   num_shards = 15

   record_reader = io_utils.read_tfrecords(
       test_utils.GOLDEN_CALLING_EXAMPLES, max_records=max_examples)
   examples = list(record_reader)

   sharded_path = test_utils.test_tmpfile('sharded@{}'.format(num_shards))
   io_utils.write_tfrecords(examples, sharded_path)

   expected_count = len(examples)
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_path,
                                                      expected_count)
 def test_call_end2end_empty_first_shard(self):
   """Runs the random-guess end-to-end check when shard 0 of 2 is empty."""
   # Use at most 10 golden examples.
   examples = list(
       io_utils.read_tfrecords(
           test_utils.GOLDEN_CALLING_EXAMPLES, max_records=10))

   # Write the two shard files by hand so the first one holds no records and
   # the second one holds all of the examples.
   shard_contents = [
       ('empty_1st_shard-00000-of-00002', []),
       ('empty_1st_shard-00001-of-00002', examples),
   ]
   for shard_name, records in shard_contents:
     shard_path = test_utils.test_tmpfile(shard_name)
     io_utils.write_tfrecords(records, shard_path)

   sharded_spec = test_utils.test_tmpfile('empty_1st_shard@2')
   self.assertCallVariantsEmitsNRecordsForRandomGuess(sharded_spec,
                                                      len(examples))
 def test_call_variants_with_no_shape(self, model):
   """Checks that prepare_inputs rejects an example lacking image/shape.

   Args:
     model: Model object forwarded unchanged to call_variants.prepare_inputs.
   """
   # Take the first record of the golden file as a known-good starting point.
   golden_records = io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES)
   example = next(golden_records)

   # Strip the image/shape feature so the example becomes invalid.
   del example.features.feature['image/shape']

   source_path = test_utils.test_tmpfile('make_examples_out_noshape.tfrecord')
   io_utils.write_tfrecords([example], source_path)

   expected_message = ('Invalid image/shape: we expect to find an image/shape '
                       'field with length 3.')
   with self.assertRaisesRegexp(ValueError, expected_message):
     call_variants.prepare_inputs(source_path, model, batch_size=1)
예제 #7
0
def make_golden_dataset(compressed_inputs=False):
    """Returns the path of the golden postprocess input to read from.

    Args:
      compressed_inputs: If false, the golden input path is returned as is.
        If true, its records are first copied into a tmp file whose name ends
        in '.tfrecord.gz' (presumably written compressed because of that
        suffix; confirm against io_utils.write_tfrecords).

    Returns:
      The path of the file to use as input.
    """
    if not compressed_inputs:
        return test_utils.GOLDEN_POSTPROCESS_INPUT

    gz_path = test_utils.test_tmpfile(
        'golden.postprocess_single_site_input.tfrecord.gz')
    golden_records = io_utils.read_tfrecords(
        test_utils.GOLDEN_POSTPROCESS_INPUT,
        proto=deepvariant_pb2.CallVariantsOutput)
    io_utils.write_tfrecords(golden_records, gz_path)
    return gz_path
def make_golden_dataset(compressed_inputs=False):
    """Builds the 'labeled_golden' dataset over the golden training examples.

    Args:
      compressed_inputs: If false, the dataset reads the golden training file
        directly. If true, its records are first copied into a tmp file named
        'make_golden_dataset.tfrecord.gz' and the dataset reads that copy.

    Returns:
      A data_providers.DeepVariantDataSet with num_examples set to 49.
    """
    golden_path = test_utils.GOLDEN_TRAINING_EXAMPLES
    if compressed_inputs:
        golden_path = test_utils.test_tmpfile(
            'make_golden_dataset.tfrecord.gz')
        training_records = io_utils.read_tfrecords(
            test_utils.GOLDEN_TRAINING_EXAMPLES)
        io_utils.write_tfrecords(training_records, golden_path)

    return data_providers.DeepVariantDataSet(
        name='labeled_golden', source=golden_path, num_examples=49)
 def test_get_shape_from_examples_path(self, file_name_to_write,
                                       tfrecord_path_to_match):
     """Checks that a dataset's tensor_shape equals the written image/shape.

     Args:
       file_name_to_write: Name passed to test_utils.test_tmpfile for the file
         that receives the single example.
       tfrecord_path_to_match: Name passed to test_utils.test_tmpfile to build
         the `source` of the dataset under test.
     """
     valid_shape = [1, 2, 3]
     example = example_pb2.Example()
     shape_feature = example.features.feature['image/shape']
     shape_feature.int64_list.value.extend(valid_shape)

     written_path = test_utils.test_tmpfile(file_name_to_write)
     io_utils.write_tfrecords([example], written_path)

     source = test_utils.test_tmpfile(tfrecord_path_to_match)
     dataset = data_providers.DeepVariantDataSet(
         name='test_shape', source=source, num_examples=1)
     self.assertEqual(valid_shape, dataset.tensor_shape)
    def test_reading_sharded_dataset(self, compressed_inputs):
        """Checks that a sharded copy of the golden dataset reads back intact.

        Args:
          compressed_inputs: Forwarded to make_golden_dataset to choose which
            golden source the sharded copy is built from.
        """
        golden_dataset = make_golden_dataset(compressed_inputs)

        # Re-write the golden records across three shards.
        shard_count = 3
        sharded_path = test_utils.test_tmpfile(
            'sharded@{}'.format(shard_count))
        golden_records = io_utils.read_tfrecords(golden_dataset.source)
        io_utils.write_tfrecords(golden_records, sharded_path)

        config_file = _test_dataset_config(
            'test_sharded.pbtxt',
            name='sharded_test',
            tfrecord_path=sharded_path,
            num_examples=golden_dataset.num_examples)

        slim_dataset = data_providers.get_dataset(
            config_file).get_slim_dataset()
        self.assertDataSetExamplesMatchExpected(slim_dataset, golden_dataset)
  def test_call_variants_with_invalid_format(self, model, bad_format):
    """Checks that call_variants raises ValueError for a bad image/format.

    Args:
      model: Model object forwarded unchanged to call_variants.call_variants.
      bad_format: Value written into the example's image/format field.
    """
    # Start from the first record of the golden file, which is known good.
    golden_records = io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES)
    example = next(golden_records)

    # Replace image/format with an invalid value (anything but 'raw').
    format_values = example.features.feature['image/format'].bytes_list.value
    format_values[0] = bad_format

    examples_path = test_utils.test_tmpfile('make_examples_output.tfrecord')
    io_utils.write_tfrecords([example], examples_path)
    output_path = test_utils.test_tmpfile(
        'call_variants_invalid_format.tfrecord')

    with self.assertRaises(ValueError):
      call_variants.call_variants(
          examples_filename=examples_path,
          checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
          model=model,
          output_file=output_path,
          batch_size=1,
          max_batches=1)
  def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Checks that prepare_inputs yields every variant in self.examples.

    Args:
      filename: Name passed to test_utils.test_tmpfile for the examples file.
      expand_to_file_pattern: If true, the path is converted with
        io_utils.NormalizeToShardedFilePattern before being read.
    """
    examples_path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, examples_path)
    if expand_to_file_pattern:
      # Transform foo@3 to foo-?????-of-00003.
      examples_path = io_utils.NormalizeToShardedFilePattern(examples_path)

    with self.test_session() as sess:
      _, variants, _ = call_variants.prepare_inputs(
          examples_path, self.model, batch_size=1)
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())

      # Pull batches until the input signals that it is exhausted.
      encoded_variants = []
      while True:
        try:
          batch = sess.run(variants)
        except tf.errors.OutOfRangeError:
          break
        encoded_variants.extend(batch)

      decoded_variants = variantutils.decode_variants(encoded_variants)
      self.assertItemsEqual(self.variants, decoded_variants)
  def test_call_end2end(self, model, shard_inputs, include_debug_info):
    """Runs call_variants end to end and validates the emitted outputs.

    Args:
      model: Model to run. Only a model whose name is 'random_guess' processes
        every example; all other models are limited to a single batch.
      shard_inputs: If true, the golden examples are re-written to a 3-way
        sharded path and read from there; otherwise the golden file is read
        directly.
      include_debug_info: Value assigned to FLAGS.include_debug_info; it also
        selects which debug_info checks are applied to the outputs.
    """
    FLAGS.include_debug_info = include_debug_info
    examples = list(io_utils.read_tfrecords(test_utils.GOLDEN_CALLING_EXAMPLES))

    if shard_inputs:
      # Create a sharded version of our golden examples.
      source_path = test_utils.test_tmpfile('sharded@{}'.format(3))
      io_utils.write_tfrecords(examples, source_path)
    else:
      source_path = test_utils.GOLDEN_CALLING_EXAMPLES

    batch_size = 4
    if model.name == 'random_guess':
      # For the random guess model we can run everything.
      max_batches = None
    else:
      # For all other models we only run a single batch for inference.
      max_batches = 1

    outfile = test_utils.test_tmpfile('call_variants.tfrecord')
    call_variants.call_variants(
        examples_filename=source_path,
        checkpoint_path=modeling.SKIP_MODEL_INITIALIZATION_IN_TEST,
        model=model,
        output_file=outfile,
        batch_size=batch_size,
        max_batches=max_batches)

    call_variants_outputs = list(
        io_utils.read_tfrecords(outfile, deepvariant_pb2.CallVariantsOutput))

    # Check that we have the right number of output protos.
    self.assertEqual(
        len(call_variants_outputs), batch_size * max_batches
        if max_batches else len(examples))

    # Check that our CallVariantsOutput (CVO) have the following critical
    # properties:
    # - we have one CVO for each example we processed.
    # - the variant in the CVO is exactly what was in the example.
    # - the alt_allele_indices of the CVO match those of its corresponding
    #   example.
    # - there are 3 genotype probabilities and these are between 0.0 and 1.0.
    # We can only do this test when processing all of the variants (max_batches
    # is None), since we processed all of the examples with that model.
    if max_batches is None:
      self.assertItemsEqual([cvo.variant for cvo in call_variants_outputs],
                            [tf_utils.example_variant(ex) for ex in examples])

    # Check the CVO debug_info: not filled if include_debug_info is False;
    # else, filled by logic based on CVO.
    if not include_debug_info:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info,
                         deepvariant_pb2.CallVariantsOutput.DebugInfo())
    else:
      for cvo in call_variants_outputs:
        self.assertEqual(cvo.debug_info.has_insertion,
                         variantutils.has_insertion(cvo.variant))
        self.assertEqual(cvo.debug_info.has_deletion,
                         variantutils.has_deletion(cvo.variant))
        self.assertEqual(cvo.debug_info.is_snp, variantutils.is_snp(
            cvo.variant))
        self.assertEqual(cvo.debug_info.predicted_label,
                         np.argmax(cvo.genotype_probabilities))

    def example_matches_call_variants_output(example, call_variants_output):
      """Returns whether the variant and alt allele indices both match."""
      return (tf_utils.example_variant(example) == call_variants_output.variant
              and tf_utils.example_alt_alleles_indices(
                  example) == call_variants_output.alt_allele_indices.indices)

    for call_variants_output in call_variants_outputs:
      # Find all matching examples.
      matches = [
          ex for ex in examples
          if example_matches_call_variants_output(ex, call_variants_output)
      ]
      # We should have exactly one match.
      self.assertEqual(len(matches), 1)
      example = matches[0]
      # Check that we've faithfully copied in the alt alleles (though currently
      # as implemented we find our example using this information so it cannot
      # fail). Included here in case that changes in the future.
      self.assertEqual(
          list(tf_utils.example_alt_alleles_indices(example)),
          list(call_variants_output.alt_allele_indices.indices))
      # We should have exactly three genotype probabilities (assuming our
      # ploidy == 2).
      self.assertEqual(len(call_variants_output.genotype_probabilities), 3)
      # These are probabilities so they should be between 0 and 1. The
      # generator must be reduced with all(): a bare generator object is always
      # truthy, which would make this assertion pass unconditionally.
      self.assertTrue(
          all(0 <= gp <= 1
              for gp in call_variants_output.genotype_probabilities))
예제 #14
0
 def write_test_protos(self, filename):
     """Writes ten ContigInfo protos to a tmp file.

     Args:
       filename: Name passed to test_utils.test_tmpfile for the output file.

     Returns:
       A (protos, path) tuple: the list of protos that were written, named
       '0' through '9', and the path they were written to.
     """
     contigs = []
     for index in range(10):
         contigs.append(core_pb2.ContigInfo(name=str(index)))
     out_path = test_utils.test_tmpfile(filename)
     io.write_tfrecords(contigs, out_path)
     return contigs, out_path