示例#1
0
    def test_exceptions_are_propogated(self, exception, text):
        """Checks how a failing write_fn surfaces to users of AsyncWriter.

        Args:
          exception: Exception class raised by the mocked write_fn.
          text: Message given to that exception; also the expected regexp.
        """
        failing_write_fn = mock.MagicMock(side_effect=exception(text))

        # Construction alone cannot fail: the error lives in write_fn, and
        # write_fn is not invoked until the queue is being processed.
        async_writer = io.AsyncWriter(failing_write_fn, maxsize=1)

        # Queuing an item outside of any with block is allowed and must not
        # raise.
        async_writer.write(1)

        # One item is already queued, so the thread fails as soon as it
        # starts. With an empty body, the only place the error can come from
        # is the writer's __exit__.
        with self.assertRaisesRegexp(exception, text), async_writer:
            pass

        # The thread has stopped, but the failure has already happened, so a
        # further write must report it again. Ideally a producer would see
        # the error from write() while the thread is still alive; this is the
        # closest approximation that avoids locking gymnastics in a test.
        with self.assertRaisesRegexp(exception, text):
            async_writer.write(2)

        # The producer was already told about the failure by write(2) above,
        # so entering and leaving the writer again must stay silent.
        with async_writer:
            pass
示例#2
0
    def test_exception_from_flush_is_thrown_to_producer(self, exception, text):
        """Checks that an error raised by flush_fn reaches the with block.

        Args:
          exception: Exception class raised by the mocked flush_fn.
          text: Message given to that exception; also the expected regexp.
        """
        failing_flush_fn = mock.MagicMock(side_effect=exception(text))
        quiet_write_fn = mock.MagicMock()

        # Nothing is written; the with block is only entered and left.
        with self.assertRaisesRegexp(exception, text), io.AsyncWriter(
                quiet_write_fn, flush_fn=failing_flush_fn):
            pass
示例#3
0
    def test_flush_is_called_when_provided(self):
        """Checks that flush_fn runs once, and not before the block exits."""
        recorded_writes = mock.MagicMock()
        recorded_flushes = mock.MagicMock()

        async_writer = io.AsyncWriter(
            recorded_writes, flush_fn=recorded_flushes, maxsize=1)
        with async_writer as writer:
            writer.write(1)
            # Still inside the block: no flush may have happened yet.
            self.assertFalse(recorded_flushes.called)

        # After the block: one write with our value, one argument-less flush.
        recorded_writes.assert_called_once_with(1)
        recorded_flushes.assert_called_once_with()
示例#4
0
    def test_writing(self, objs):
        """Checks that write_fn sees every written item, in order.

        Args:
          objs: Items to write. Tuples are expanded into positional arguments
            of write(); anything else is passed as a single argument.
        """
        # Normalize to tuples so each entry can be splatted into write().
        arg_tuples = [
            obj if isinstance(obj, tuple) else (obj,) for obj in objs
        ]

        # Push every argument tuple through a fresh writer.
        recording_write_fn = mock.MagicMock()
        with io.AsyncWriter(recording_write_fn, maxsize=100) as writer:
            for args in arg_tuples:
                writer.write(*args)

        # The recorded calls must match the expected ones exactly, including
        # their order.
        expected_calls = [mock.call(*args) for args in arg_tuples]
        self.assertEqual(expected_calls, recording_write_fn.call_args_list)
示例#5
0
    def test_thread_state(self):
        """Checks the writer's _thread before, inside and after a with block."""

        def _ignore(_):
            # Stand-in write_fn: accepts one argument and returns 0.
            return 0

        async_writer = io.AsyncWriter(_ignore, maxsize=100)

        # No thread exists right after construction.
        self.assertIsNone(async_writer._thread)

        with async_writer:
            # Inside the block a thread must exist and be marked as daemon.
            worker = async_writer._thread
            self.assertIsNotNone(worker)
            self.assertTrue(worker.daemon)

        # Leaving the block resets the thread back to None.
        self.assertIsNone(async_writer._thread)
示例#6
0
    def test_exception_with_non_empty_queue_and_exit(self):
        """Regression test: exiting must not hang when write_fn has failed.

        Guards against a previously observed bug: when write_fn raised while
        items were still queued, an earlier implementation hung on leaving
        the with block.
        """
        failing_write_fn = mock.MagicMock(side_effect=KeyError('write_fn'))
        async_writer = io.AsyncWriter(failing_write_fn)
        for item in (1, 2):
            async_writer.write(item)

        # Two items are queued. Handling the first one raises, which leaves
        # the second one in the queue together with the pending error. The
        # with block must still fail immediately instead of hanging.
        with self.assertRaisesRegexp(KeyError, 'write_fn'), async_writer:
            pass
def make_async_writer(write_fn):
    """Builds an AsyncWriter that emits CallVariantsOutput protos to write_fn.

    The returned writer's write() takes three arguments:
      * encoded_variant: a serialized Variant proto.
      * gls: the genotype probabilities, stored in the output proto as
        `genotype_probabilities`.
      * encoded_alt_allele_indices: a serialized
        CallVariantsOutput.AltAlleleIndices proto.

    Each write decodes the two serialized protos, assembles a
    CallVariantsOutput from them and from gls, and hands that proto to
    write_fn. When FLAGS.include_debug_info is set, the output also carries a
    DebugInfo message (insertion/deletion/SNP flags and the argmax of gls as
    the predicted label).

    Args:
      write_fn: A function accepting a CallVariantsOutput proto that writes
        to its underlying writer.

    Returns:
      An AsyncWriter.
    """
    def write_output(encoded_variant, gls, encoded_alt_allele_indices):
        """Decodes the inputs and passes one CallVariantsOutput to write_fn."""
        decoded_variant = variants_pb2.Variant.FromString(encoded_variant)
        output_cls = deepvariant_pb2.CallVariantsOutput
        decoded_indices = output_cls.AltAlleleIndices.FromString(
            encoded_alt_allele_indices)

        # Debug info is optional; the field is given None when disabled.
        if FLAGS.include_debug_info:
            debug_info = output_cls.DebugInfo(
                has_insertion=variantutils.has_insertion(decoded_variant),
                has_deletion=variantutils.has_deletion(decoded_variant),
                is_snp=variantutils.is_snp(decoded_variant),
                predicted_label=np.argmax(gls))
        else:
            debug_info = None

        write_fn(
            output_cls(
                variant=decoded_variant,
                alt_allele_indices=decoded_indices,
                genotype_probabilities=gls,
                debug_info=debug_info))

    return io_utils.AsyncWriter(write_output)
示例#8
0
    def test_write_blocking_and_timeout(self):
        """Checks write() against a full queue, non-blocking and timed."""
        recording_write_fn = mock.MagicMock()
        async_writer = io.AsyncWriter(recording_write_fn, maxsize=1)

        # The queue starts empty, so the first write goes through at once.
        async_writer.write(1)

        # The single slot is taken: a non-blocking write fails with Full
        # right away.
        with self.assertRaises(Queue.Full):
            async_writer.write(2, block=False)

        # A blocking write with a timeout ends in Full as well. If the
        # timeout were ignored, this call would never return and the unittest
        # itself would time out.
        with self.assertRaises(Queue.Full):
            async_writer.write(3, block=True, timeout=1)

        # Entering the writer starts its thread, so by the time the block is
        # left every queued element has been written.
        with async_writer:
            pass
        recording_write_fn.assert_called_once_with(1)
def write_call_variants_output_to_vcf(contigs, input_sorted_tfrecord_path,
                                      output_vcf_path, qual_filter,
                                      multi_allelic_qual_filter, sample_name):
    """Converts sorted CallVariantsOutput protos into a VCF file.

    Consecutive input records that share the same variant range are grouped
    and merged into one canonical variant plus its predictions; the call is
    then added to that variant and the result is written asynchronously.

    Two filters are applied along the way: 1) variants whose quality is below
    `qual_filter` are filtered. 2) for multi-allelic variants, individual
    alleles whose quality is below `multi_allelic_qual_filter` are omitted.

    Args:
      contigs: list(ContigInfo). A list of the reference genome contigs for
        writers that need contig information.
      input_sorted_tfrecord_path: str. TFRecord format file containing sorted
        CallVariantsOutput protos.
      output_vcf_path: str. Output file in VCF format.
      qual_filter: double. The qual value below which to filter variants.
      multi_allelic_qual_filter: double. The qual value below which to filter
        multi-allelic variants.
      sample_name: str. Sample name to write to VCF file.
    """
    logging.info('Writing calls to VCF file: %s', output_vcf_path)
    sync_writer, writer_fn = genomics_io.make_variant_writer(
        output_vcf_path, contigs, samples=[sample_name], filters=FILTERS)

    def _range_of(call_variants_output):
        # Grouping key: records covering the same variant range are merged.
        return variantutils.variant_range(call_variants_output.variant)

    with sync_writer, io_utils.AsyncWriter(writer_fn) as writer:
        records = io_utils.read_tfrecords(
            input_sorted_tfrecord_path,
            proto=deepvariant_pb2.CallVariantsOutput)
        # groupby only merges adjacent records, hence the sorted input.
        for _, same_range_records in itertools.groupby(records, _range_of):
            canonical_variant, predictions = merge_predictions(
                list(same_range_records), multi_allelic_qual_filter)
            writer.write(
                add_call_to_variant(canonical_variant,
                                    predictions,
                                    qual_filter=qual_filter,
                                    sample_name=sample_name))