Exemplo n.º 1
0
    def test_shard_sorted_tfrecords_max_records(self, filename, max_records):
        """Checks how many records come back for a given `max_records`.

        Args:
          filename: Forwarded to `self.write_test_protos` to produce the data.
          max_records: Cap handed to the reader; when None, every proto
            returned by `write_test_protos` is expected back.
        """
        protos, path = self.write_test_protos(filename)
        n_written = len(protos)

        # The cap only matters when it is smaller than what was written.
        expected_n = (n_written if max_records is None
                      else min(max_records, n_written))

        records = io.read_shard_sorted_tfrecords(
            path,
            key=lambda contig: int(contig.name),
            proto=reference_pb2.ContigInfo,
            max_records=max_records)
        self.assertLen(list(records), expected_n)
Exemplo n.º 2
0
  def test_shard_sorted_tfrecords_max_records(self, filename, max_records):
    """Verifies the record count produced under a `max_records` limit.

    Args:
      filename: Forwarded to `self.write_test_protos` to produce the data.
      max_records: Limit passed to the reader; None means all protos returned
        by `write_test_protos` are expected.
    """
    protos, path = self.write_test_protos(filename)

    # Start from "everything" and tighten only when a limit was supplied.
    expected_n = len(protos)
    if max_records is not None:
      expected_n = min(max_records, expected_n)

    def contig_key(contig):
      return int(contig.name)

    records = io.read_shard_sorted_tfrecords(
        path,
        key=contig_key,
        proto=reference_pb2.ContigInfo,
        max_records=max_records)
    self.assertLen(list(records), expected_n)
Exemplo n.º 3
0
  def test_shard_sorted_tfrecords(self, filename):
    """Round-trips test protos through `io.read_shard_sorted_tfrecords`.

    Asserts that the reader is a generator, that it yields exactly the protos
    returned by `self.write_test_protos`, and that the yielded sequence is
    already ordered by the integer value of each proto's `name`.

    Args:
      filename: Forwarded to `self.write_test_protos` to produce the data.
    """
    protos, path = self.write_test_protos(filename)

    def name_as_int(contig):
      return int(contig.name)

    reader = io.read_shard_sorted_tfrecords(
        path, proto=reference_pb2.ContigInfo, key=name_as_int)

    # The reader must be lazy: an actual generator object, not a list.
    self.assertEqual(type(reader), types.GeneratorType)

    # Drain it once, then compare against both the input and its sorted form.
    contents = list(reader)
    self.assertEqual(protos, contents)
    self.assertEqual(contents, sorted(contents, key=name_as_int))
Exemplo n.º 4
0
def main(argv=()):
    """Runs postprocess_variants: writes a VCF and, if requested, a gVCF.

    Steps, all inside `errors.clean_commandline_error_exit()`:
      1. Reject positional arguments and inconsistent gVCF flags.
      2. Read contigs from the reference named by `FLAGS.ref`.
      3. Take one CallVariantsOutput record from `FLAGS.infile` to get the
         sample name and build the VCF header.
      4. Process the input into a temporary file, transform it to variants
         and write them to `FLAGS.outfile`.
      5. When `FLAGS.nonvariant_site_tfrecord_path` is set, merge the written
         VCF with the non-variant records and write `FLAGS.gvcf_outfile`.

    Args:
      argv: Command-line arguments; more than one entry is reported as a
        `errors.CommandLineError`.

    Raises:
      ValueError: If no record can be read from the input paths.
    """
    with errors.clean_commandline_error_exit():
        if len(argv) > 1:
            errors.log_and_raise(
                'Command line parsing failure: postprocess_variants does not accept '
                'positional arguments but some are present on the command line: '
                '"{}".'.format(str(argv)), errors.CommandLineError)
        del argv  # Unused.

        # The two gVCF flags must be given together or not at all.
        has_nonvariant_input = bool(FLAGS.nonvariant_site_tfrecord_path)
        has_gvcf_output = bool(FLAGS.gvcf_outfile)
        if has_nonvariant_input != has_gvcf_output:
            errors.log_and_raise(
                'gVCF creation requires both nonvariant_site_tfrecord_path and '
                'gvcf_outfile flags to be set.', errors.CommandLineError)

        proto_utils.uses_fast_cpp_protos_or_die()
        logging_level.set_from_flag()

        fasta_reader = fasta.IndexedFastaReader(
            FLAGS.ref, cache_size=_FASTA_CACHE_SIZE)
        contigs = fasta_reader.header.contigs
        paths = io_utils.maybe_generate_sharded_filenames(FLAGS.infile)

        # A single CallVariantsOutput record supplies the sample name. This
        # assumes every CallVariantsOutput proto in the infile holds exactly
        # one VariantCall in its Variant proto, and that call_set_name is the
        # same across all records.
        joined_paths = ','.join(paths)
        first_record = tf_utils.get_one_example_from_examples_path(
            joined_paths, proto=deepvariant_pb2.CallVariantsOutput)
        if first_record is None:
            raise ValueError(
                'Cannot find any records in {}'.format(joined_paths))

        sample_name = _extract_single_sample_name(first_record)
        header = dv_vcf_constants.deepvariant_header(
            contigs=contigs, sample_names=[sample_name])

        # The temporary file only needs to live until the VCF is written.
        with tempfile.NamedTemporaryFile() as sorted_tmp:
            sorted_path = sorted_tmp.name
            postprocess_variants_lib.process_single_sites_tfrecords(
                contigs, paths, sorted_path)
            unresolved = _transform_call_variants_output_to_variants(
                input_sorted_tfrecord_path=sorted_path,
                qual_filter=FLAGS.qual_filter,
                multi_allelic_qual_filter=FLAGS.multi_allelic_qual_filter,
                sample_name=sample_name)
            resolved = haplotypes.maybe_resolve_conflicting_variants(
                unresolved)
            write_variants_to_vcf(
                variant_generator=resolved,
                output_vcf_path=FLAGS.outfile,
                header=header)

        # Nothing more to do unless a gVCF was requested.
        if not FLAGS.nonvariant_site_tfrecord_path:
            return

        nonvariants = io_utils.read_shard_sorted_tfrecords(
            FLAGS.nonvariant_site_tfrecord_path,
            key=_get_contig_based_variant_sort_keyfn(contigs),
            proto=variants_pb2.Variant)
        # The gVCF is built from the VCF that was just written to disk.
        with vcf.VcfReader(FLAGS.outfile) as variant_reader:
            lessthanfn = _get_contig_based_lessthan(contigs)
            gvcf_records = (
                _transform_to_gvcf_record(called)
                for called in variant_reader.iterate())
            merged = merge_variants_and_nonvariants(
                gvcf_records, nonvariants, lessthanfn, fasta_reader)
            write_variants_to_vcf(
                variant_generator=merged,
                output_vcf_path=FLAGS.gvcf_outfile,
                header=header)
Exemplo n.º 5
0
def main(argv=()):
  """Runs postprocess_variants: writes a VCF and, if requested, a gVCF.

  All work happens inside `errors.clean_commandline_error_exit()`. The function
  validates the command line, reads contigs from `FLAGS.ref`, takes the sample
  name from the first record of the first input path, writes the called
  variants to `FLAGS.outfile`, and, when `FLAGS.nonvariant_site_tfrecord_path`
  is set, merges that VCF with the non-variant records into
  `FLAGS.gvcf_outfile`.

  Args:
    argv: Command-line arguments; more than one entry is reported as a
      `errors.CommandLineError`.

  Raises:
    ValueError: If the first input path contains no records.
  """
  with errors.clean_commandline_error_exit():
    if len(argv) > 1:
      errors.log_and_raise(
          'Command line parsing failure: postprocess_variants does not accept '
          'positional arguments but some are present on the command line: '
          '"{}".'.format(str(argv)), errors.CommandLineError)
    del argv  # Unused.

    # The two gVCF flags must be given together or not at all.
    if (not FLAGS.nonvariant_site_tfrecord_path) != (not FLAGS.gvcf_outfile):
      errors.log_and_raise(
          'gVCF creation requires both nonvariant_site_tfrecord_path and '
          'gvcf_outfile flags to be set.', errors.CommandLineError)

    proto_utils.uses_fast_cpp_protos_or_die()

    logging_level.set_from_flag()

    fasta_reader = fasta.RefFastaReader(FLAGS.ref, cache_size=_FASTA_CACHE_SIZE)
    contigs = fasta_reader.header.contigs
    paths = io_utils.maybe_generate_sharded_filenames(FLAGS.infile)
    # Read one CallVariantsOutput record and extract the sample name from it.
    # Note that this assumes that all CallVariantsOutput protos in the infile
    # contain a single VariantCall within their constituent Variant proto, and
    # that the call_set_name is identical in each of the records.
    #
    # Pass a default to next() so an empty first file is reported with a clear
    # error naming that file, instead of escaping as a bare StopIteration.
    record = next(
        io_utils.read_tfrecords(
            paths[0], proto=deepvariant_pb2.CallVariantsOutput, max_records=1),
        None)
    if record is None:
      raise ValueError('Cannot find any records in {}'.format(paths[0]))
    sample_name = _extract_single_sample_name(record)
    header = dv_vcf_constants.deepvariant_header(
        contigs=contigs, sample_names=[sample_name])
    # The temporary file only needs to live until the VCF has been written.
    with tempfile.NamedTemporaryFile() as temp:
      postprocess_variants_lib.process_single_sites_tfrecords(
          contigs, paths, temp.name)
      independent_variants = _transform_call_variants_output_to_variants(
          input_sorted_tfrecord_path=temp.name,
          qual_filter=FLAGS.qual_filter,
          multi_allelic_qual_filter=FLAGS.multi_allelic_qual_filter,
          sample_name=sample_name)
      variant_generator = haplotypes.maybe_resolve_conflicting_variants(
          independent_variants)
      write_variants_to_vcf(
          variant_generator=variant_generator,
          output_vcf_path=FLAGS.outfile,
          header=header)

    # Also write out the gVCF file if it was provided. It is built from the VCF
    # that was just written to FLAGS.outfile.
    if FLAGS.nonvariant_site_tfrecord_path:
      nonvariant_generator = io_utils.read_shard_sorted_tfrecords(
          FLAGS.nonvariant_site_tfrecord_path,
          key=_get_contig_based_variant_sort_keyfn(contigs),
          proto=variants_pb2.Variant)
      with vcf.VcfReader(FLAGS.outfile, use_index=False) as variant_reader:
        lessthanfn = _get_contig_based_lessthan(contigs)
        gvcf_variants = (
            _transform_to_gvcf_record(variant)
            for variant in variant_reader.iterate())
        merged_variants = merge_variants_and_nonvariants(
            gvcf_variants, nonvariant_generator, lessthanfn, fasta_reader)
        write_variants_to_vcf(
            variant_generator=merged_variants,
            output_vcf_path=FLAGS.gvcf_outfile,
            header=header)