Ejemplo n.º 1
0
 def test_decode_variants(self):
     """Round-trips serialized variants through variantutils.decode_variants."""
     expected = [test_utils.make_variant(start=pos) for pos in (1, 2)]
     serialized = [proto.SerializeToString() for proto in expected]
     decoded = variantutils.decode_variants(serialized)
     # decode_variants hands back an iterable rather than a list, so comparing
     # it directly against the list of inputs must not succeed ...
     self.assertNotEqual(decoded, expected)
     # ... whereas materializing it into a list reproduces the inputs.
     self.assertEqual(list(decoded), expected)
Ejemplo n.º 2
0
    def test_get_training_batches(self, compressed_inputs):
        """Validates one batch produced by data_providers.make_training_batches.

        Builds the batch op against a mocked model, runs it once, and checks
        the shapes of the returned images, labels and encoded variants.

        Args:
          compressed_inputs: forwarded unchanged to make_golden_dataset to
            build the dataset under test.
        """
        golden_dataset = make_golden_dataset(compressed_inputs)
        batch_size = 16
        with tf.Session() as sess:
            # NOTE(review): MagicMock has no `autospec` parameter; unknown
            # keyword arguments are simply set as attributes on the mock, so
            # this does not restrict the mock to the DeepVariantModel
            # interface. Left unchanged because switching to `spec=` could
            # alter which attribute accesses succeed -- confirm before fixing.
            mock_model = mock.MagicMock(autospec=modeling.DeepVariantModel)
            # Stand in for the model's preprocessing with a fixed-size
            # crop/pad so the image shape asserted below is known.
            mock_model.preprocess_image.side_effect = functools.partial(
                tf.image.resize_image_with_crop_or_pad,
                target_height=107,
                target_width=221)
            batch = data_providers.make_training_batches(
                golden_dataset.get_slim_dataset(), mock_model, batch_size)

            # We should have called our preprocess_image exactly once. We don't
            # have the actual objects to test for the call, though.
            test_utils.assert_called_once_workaround(
                mock_model.preprocess_image)

            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            # Everything that can fail once the queue runners are live sits in
            # the try block, so the threads are always stopped and joined
            # before the session is closed, even when an assertion fails.
            try:
                # Get our images, labels, and variants for further testing.
                images, labels, variants = sess.run(batch)

                # Images must match the crop/pad size configured above.
                self.assertEqual(
                    (batch_size, 107, 221, pileup_image.DEFAULT_NUM_CHANNEL),
                    images.shape)
                # Labels are one scalar per example (not one-hot), each
                # within the closed range [0, 2].
                self.assertEqual((batch_size, ), labels.shape)
                for label in labels:
                    self.assertTrue(0 <= label <= 2)

                # Check that our variants have the shape we expect and
                # actually contain variants by decoding them and checking the
                # reference_name.
                self.assertEqual((batch_size, ), variants.shape)
                for variant in variantutils.decode_variants(variants):
                    self.assertEqual(variant.reference_name, 'chr20')
            finally:
                # Shutdown tensorflow
                coord.request_stop()
                coord.join(threads)
  def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Checks that call_variants.prepare_inputs yields every written variant.

    Args:
      filename: name of the temporary TFRecord file that self.examples is
        written to.
      expand_to_file_pattern: when truthy, the path is converted to a sharded
        file pattern before being handed to prepare_inputs.
    """
    source_path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, source_path)
    if expand_to_file_pattern:
      # Transform foo@3 to foo-?????-of-00003.
      source_path = io_utils.NormalizeToShardedFilePattern(source_path)

    with self.test_session() as sess:
      prepared = call_variants.prepare_inputs(
          source_path, self.model, batch_size=1)
      # Only the middle element (the encoded variants) is inspected here.
      _, variants, _ = prepared
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())

      # Keep pulling batches until the input signals that it is exhausted.
      encoded = []
      while True:
        try:
          fetched = sess.run(variants)
        except tf.errors.OutOfRangeError:
          break
        encoded.extend(fetched)

      decoded = variantutils.decode_variants(encoded)
      self.assertItemsEqual(self.variants, decoded)
Ejemplo n.º 4
0
 def _select(encoded_variants):
     """Returns a float32 array with one weight per encoded variant.

     Each weight is variant_p_func applied to the decoded variant, multiplied
     by 1.0 before being stored.
     """
     weights = []
     for decoded in variantutils.decode_variants(encoded_variants):
         weights.append(1.0 * variant_p_func(decoded))
     return np.array(weights, dtype=np.float32)