示例#1
0
    def test_prepare_inputs(self, filename_to_write, file_string_input):
        """Check that prepare_inputs hands back every variant written to disk.

        self.examples is written to a path obtained from
        test_utils.test_tmpfile, the input spec is resolved the same way, and
        the variants pulled out of the resulting batches are compared,
        ignoring order, against self.variants.

        Args:
          filename_to_write: File name passed to test_utils.test_tmpfile to
            build the path the examples are written to.
          file_string_input: A single file name or several names joined by
            commas; each name is passed through test_utils.test_tmpfile
            before the spec is handed to prepare_inputs.
        """
        tfrecord.write_tfrecords(self.examples,
                                 test_utils.test_tmpfile(filename_to_write))
        # The spec may list several files separated by commas, so each entry
        # is prefixed individually and the pieces are re-joined afterwards.
        prefixed_names = [
            test_utils.test_tmpfile(name)
            for name in file_string_input.split(',')
        ]
        file_string_input = ','.join(prefixed_names)

        with self.test_session() as sess:
            local_init = tf.compat.v1.local_variables_initializer()
            sess.run(local_init)
            global_init = tf.compat.v1.global_variables_initializer()
            sess.run(global_init)

            dataset = call_variants.prepare_inputs(file_string_input)
            _, variants_op, _ = _get_infer_batches(
                dataset, model=self.model, batch_size=1)

            # Keep pulling batches until TensorFlow signals that the input
            # has been exhausted.
            collected = []
            while True:
                try:
                    batch = sess.run(variants_op)
                except tf.errors.OutOfRangeError:
                    break
                collected.extend(batch)

            decoded = variant_utils.decode_variants(collected)
            six.assertCountEqual(self, self.variants, decoded)
    def test_get_batches(self, compressed_inputs, mode, use_tpu):
        """Pull one batch from the golden dataset and sanity-check it.

        Args:
          compressed_inputs: Forwarded unchanged to make_golden_dataset.
          mode: The string 'EVAL' selects tf.estimator.ModeKeys.EVAL; any
            other value selects tf.estimator.ModeKeys.TRAIN.
          use_tpu: Forwarded to make_golden_dataset. When truthy, each variant
            in the batch is passed through tf_utils.int_tensor_to_string
            before it is decoded.
        """
        if mode == 'EVAL':
            mode = tf.estimator.ModeKeys.EVAL
        else:
            mode = tf.estimator.ModeKeys.TRAIN
        batch_size = 16
        input_fn = make_golden_dataset(
            compressed_inputs, mode=mode, use_tpu=use_tpu)

        with tf.Session() as sess:
            dataset = input_fn(dict(batch_size=batch_size))
            next_batch = dataset.make_one_shot_iterator().get_next()

            # Materialise a single batch so its pieces can be inspected.
            sess.run(tf.global_variables_initializer())
            features, labels = sess.run(next_batch)
            variants = features['variant']
            images = features['image']

            # Images must still have the default pileup dimensions. Note that
            # the shape is 100, not 107, because we only adjust the image in
            # the model_fn now, where previously it was done in the input_fn.
            expected_image_shape = (
                [batch_size] + dv_constants.PILEUP_DEFAULT_DIMS)
            self.assertEqual(expected_image_shape, list(images.shape))

            # Labels form a flat vector with one class id per example, and
            # every id must be a valid class index.
            self.assertEqual((batch_size, ), labels.shape)
            for label in labels:
                # pylint: disable=g-generic-assert
                self.assertTrue(0 <= label < dv_constants.NUM_CLASSES)

            # There is one encoded variant per example; decode each of them
            # and confirm it carries the expected reference_name.
            self.assertEqual(batch_size, variants.shape[0])
            for encoded in variants:
                if use_tpu:
                    encoded = tf_utils.int_tensor_to_string(encoded)
                for decoded in variant_utils.decode_variants([encoded]):
                    self.assertEqual(decoded.reference_name, 'chr20')
示例#3
0
    def test_prepare_inputs(self, filename, expand_to_file_pattern):
        """Check that prepare_inputs hands back every variant written to disk.

        Args:
          filename: File name passed to test_utils.test_tmpfile to build the
            path self.examples is written to.
          expand_to_file_pattern: When truthy, the written path is converted
            with io_utils.NormalizeToShardedFilePattern before it is given to
            prepare_inputs.
        """
        written_path = test_utils.test_tmpfile(filename)
        io_utils.write_tfrecords(self.examples, written_path)
        # When requested, transform foo@3 to foo-?????-of-00003.
        input_spec = (io_utils.NormalizeToShardedFilePattern(written_path)
                      if expand_to_file_pattern else written_path)

        with self.test_session() as sess:
            local_init = tf.local_variables_initializer()
            sess.run(local_init)
            global_init = tf.global_variables_initializer()
            sess.run(global_init)

            dataset = call_variants.prepare_inputs(input_spec)
            _, variants_op, _ = data_providers.get_infer_batches(
                dataset, model=self.model, batch_size=1)

            # Keep pulling batches until TensorFlow signals that the input
            # has been exhausted.
            collected = []
            while True:
                try:
                    batch = sess.run(variants_op)
                except tf.errors.OutOfRangeError:
                    break
                collected.extend(batch)

            decoded = variant_utils.decode_variants(collected)
            self.assertItemsEqual(self.variants, decoded)
示例#4
0
 def test_decode_variants(self):
   """decode_variants round-trips serialized variants as a non-list iterable."""
   expected = [test_utils.make_variant(start=pos) for pos in (1, 2)]
   serialized = [proto.SerializeToString() for proto in expected]
   decoded = variant_utils.decode_variants(serialized)
   # The return value is an iterable, not a list, so comparing it directly
   # against the list of variants must report a difference.
   self.assertNotEqual(decoded, expected)
   # Once materialised into a list the contents match the originals.
   self.assertEqual(list(decoded), expected)
示例#5
0
 def test_decode_variants(self):
   """Serialized variants decode back to the originals via an iterable."""
   originals = [test_utils.make_variant(start=begin) for begin in (1, 2)]
   wire_bytes = [original.SerializeToString() for original in originals]
   result = variant_utils.decode_variants(wire_bytes)
   # decode_variants gives back an iterable rather than a list, hence the
   # direct comparison with the list of originals is expected to differ.
   self.assertNotEqual(result, originals)
   # Turning the iterable into a list makes the two compare equal.
   self.assertEqual(list(result), originals)
示例#6
0
    def test_get_batches(self, compressed_inputs, mode):
        """Build one batch with make_batches and sanity-check its contents.

        A mocked model stands in for the real one; its preprocess_image
        resizes (crop or pad) every image to 107 x 221, which is the image
        shape asserted on the batch below.

        Args:
          compressed_inputs: Forwarded unchanged to make_golden_dataset.
          mode: Forwarded unchanged to data_providers.make_batches.
        """
        golden_dataset = make_golden_dataset(compressed_inputs)
        batch_size = 16
        with tf.Session() as sess:
            # NOTE(review): MagicMock has no `autospec` parameter; extra
            # keyword arguments only set attributes on the mock, so no spec
            # is enforced here. `spec=` or mock.create_autospec may have been
            # intended -- confirm before changing.
            mock_model = mock.MagicMock(autospec=modeling.DeepVariantModel)
            mock_model.preprocess_image.side_effect = functools.partial(
                tf.image.resize_image_with_crop_or_pad,
                target_height=107,
                target_width=221)
            batch = data_providers.make_batches(
                golden_dataset.get_slim_dataset(),
                mock_model,
                batch_size,
                mode=mode)

            # We should have called our preprocess_image exactly once. We don't
            # have the actual objects to test for the call, though.
            test_utils.assert_called_once_workaround(
                mock_model.preprocess_image)

            # Start the queue runners that feed the batch tensors.
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            # Everything that can fail runs inside try/finally so the queue
            # runner threads are always stopped and joined, even when
            # sess.run raises or an assertion fails.
            try:
                images, labels, variants = sess.run(batch)

                # Images carry the shape produced by the mocked
                # preprocess_image; labels are a flat vector holding one
                # class id in [0, 2] per example.
                self.assertEqual(
                    (batch_size, 107, 221, pileup_image.DEFAULT_NUM_CHANNEL),
                    images.shape)
                self.assertEqual((batch_size, ), labels.shape)
                for label in labels:
                    self.assertTrue(0 <= label <= 2)

                # Check that our variants has the shape we expect and actually
                # contain variants by decoding them and checking the
                # reference_name.
                self.assertEqual((batch_size, ), variants.shape)
                for variant in variant_utils.decode_variants(variants):
                    self.assertEqual(variant.reference_name, 'chr20')
            finally:
                # Shutdown tensorflow
                coord.request_stop()
                coord.join(threads)
示例#7
0
  def test_prepare_inputs(self, filename, expand_to_file_pattern):
    """Check that prepare_inputs hands back every variant written to disk.

    Args:
      filename: File name passed to test_utils.test_tmpfile to build the path
        self.examples is written to.
      expand_to_file_pattern: When truthy, the written path is converted with
        io_utils.NormalizeToShardedFilePattern before it is given to
        prepare_inputs.
    """
    written_path = test_utils.test_tmpfile(filename)
    io_utils.write_tfrecords(self.examples, written_path)
    # When requested, transform foo@3 to foo-?????-of-00003.
    input_spec = (io_utils.NormalizeToShardedFilePattern(written_path)
                  if expand_to_file_pattern else written_path)

    with self.test_session() as sess:
      # The input pipeline is built before the initializers run; presumably
      # so that the initializer ops cover anything it creates.
      _, variants_op, _ = call_variants.prepare_inputs(
          input_spec, self.model, batch_size=1)
      local_init = tf.local_variables_initializer()
      sess.run(local_init)
      global_init = tf.global_variables_initializer()
      sess.run(global_init)

      # Keep pulling batches until TensorFlow signals that the input has been
      # exhausted.
      collected = []
      while True:
        try:
          batch = sess.run(variants_op)
        except tf.errors.OutOfRangeError:
          break
        collected.extend(batch)

      decoded = variant_utils.decode_variants(collected)
      self.assertItemsEqual(self.variants, decoded)
示例#8
0
  def test_get_batches(self, compressed_inputs, mode):
    """Build one batch with make_batches and sanity-check its contents.

    A mocked model stands in for the real one; its preprocess_image resizes
    (crop or pad) every image to 107 x 221, which is the image shape asserted
    on the batch below.

    Args:
      compressed_inputs: Forwarded unchanged to make_golden_dataset.
      mode: Forwarded unchanged to data_providers.make_batches.
    """
    golden_dataset = make_golden_dataset(compressed_inputs)
    batch_size = 16
    with tf.Session() as sess:
      # NOTE(review): MagicMock has no `autospec` parameter; extra keyword
      # arguments only set attributes on the mock, so no spec is enforced
      # here. `spec=` or mock.create_autospec may have been intended --
      # confirm before changing.
      mock_model = mock.MagicMock(autospec=modeling.DeepVariantModel)
      mock_model.preprocess_image.side_effect = functools.partial(
          tf.image.resize_image_with_crop_or_pad,
          target_height=107,
          target_width=221)
      batch = data_providers.make_batches(
          golden_dataset.get_slim_dataset(), mock_model, batch_size, mode=mode)

      # We should have called our preprocess_image exactly once. We don't have
      # the actual objects to test for the call, though.
      test_utils.assert_called_once_workaround(mock_model.preprocess_image)

      # Start the queue runners that feed the batch tensors.
      sess.run(tf.global_variables_initializer())
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord, sess=sess)
      # Everything that can fail runs inside try/finally so the queue runner
      # threads are always stopped and joined, even when sess.run raises or an
      # assertion fails.
      try:
        images, labels, variants = sess.run(batch)

        # Images carry the shape produced by the mocked preprocess_image;
        # labels are a flat vector holding one class id in [0, 2] per example.
        self.assertEqual(
            (batch_size, 107, 221, pileup_image.DEFAULT_NUM_CHANNEL),
            images.shape)
        self.assertEqual((batch_size,), labels.shape)
        for label in labels:
          self.assertTrue(0 <= label <= 2)

        # Check that our variants has the shape we expect and actually contain
        # variants by decoding them and checking the reference_name.
        self.assertEqual((batch_size,), variants.shape)
        for variant in variant_utils.decode_variants(variants):
          self.assertEqual(variant.reference_name, 'chr20')
      finally:
        # Shutdown tensorflow
        coord.request_stop()
        coord.join(threads)
示例#9
0
 def _select(encoded_variants):
     """Return a float32 array with one weight per encoded variant.

     Each entry is 1.0 * variant_p_func(variant) for the corresponding
     variant obtained from variant_utils.decode_variants; variant_p_func is
     taken from the enclosing scope.
     """
     decoded = variant_utils.decode_variants(encoded_variants)
     return np.array([1.0 * variant_p_func(item) for item in decoded],
                     dtype=np.float32)
示例#10
0
 def _select(encoded_variants):
   """Return a float32 array with one weight per encoded variant.

   Each entry is 1.0 * variant_p_func(variant) for the corresponding variant
   obtained from variant_utils.decode_variants; variant_p_func is taken from
   the enclosing scope.
   """
   decoded = variant_utils.decode_variants(encoded_variants)
   return np.array([1.0 * variant_p_func(item) for item in decoded],
                   dtype=np.float32)