def test_read_file_pattern_large(self):
    # NOTE(review): stray module-level duplicate of
    # VcfSourceTest.test_read_file_pattern_large (it takes `self` but is not
    # inside a class, so it is never collected as a test) — confirm it can be
    # removed. Indentation normalized from 5 spaces to 4.
    read_data = self._read_records(
        os.path.join(testdata_util.get_full_dir(), 'valid-*.vcf'))
    self.assertEqual(9900, len(read_data))
    read_data_gz = self._read_records(
        os.path.join(testdata_util.get_full_dir(), 'valid-*.vcf.gz'))
    self.assertEqual(9900, len(read_data_gz))
 def test_pipeline_read_all_gzip_large(self):
     self._assert_pipeline_read_files_record_count_equal(os.path.join(
         testdata_util.get_full_dir(), 'valid-*.vcf.gz'),
                                                         9900,
                                                         use_read_all=True)
 def test_pipeline_read_file_pattern_large(self):
     self._assert_pipeline_read_files_record_count_equal(
         os.path.join(testdata_util.get_full_dir(), 'valid-*.vcf'), 9900)
class VcfSourceTest(unittest.TestCase):
    """Unit tests for reading VCF records through ``VcfSource``."""

    # True when the large on-disk VCF test-data directory is absent; tests
    # that need those files are skipped via ``unittest.skipIf`` below.
    VCF_FILE_DIR_MISSING = not os.path.exists(testdata_util.get_full_dir())
    # Probe at class-definition time whether the optional Nucleus parser is
    # available: the attribute access raises AttributeError when it is not.
    try:
        vcfio.vcf_parser.nucleus_vcf_reader
    except AttributeError:
        NUCLEUS_IMPORT_MISSING = True
    else:
        NUCLEUS_IMPORT_MISSING = False

    def _create_temp_vcf_file(self,
                              lines,
                              tempdir,
                              compression_type=CompressionTypes.UNCOMPRESSED):
        """Writes ``lines`` to a temp file whose suffix matches the compression.

        Raises:
          ValueError: If ``compression_type`` is not one of the supported
            CompressionTypes values.
        """
        suffix_for_type = {
            CompressionTypes.UNCOMPRESSED: '.vcf',
            CompressionTypes.AUTO: '.vcf',
            CompressionTypes.GZIP: '.vcf.gz',
            CompressionTypes.BZIP2: '.vcf.bz2',
        }
        if compression_type not in suffix_for_type:
            raise ValueError(
                'Unrecognized compression type {}'.format(compression_type))
        return tempdir.create_temp_file(
            suffix=suffix_for_type[compression_type],
            lines=lines,
            compression_type=compression_type)

    def _read_records(self,
                      file_or_pattern,
                      representative_header_lines=None,
                      vcf_parser_type=VcfParserType.PYVCF,
                      **kwargs):
        """Reads every record a VcfSource produces for the given file/pattern."""
        source = VcfSource(
            file_or_pattern,
            representative_header_lines=representative_header_lines,
            vcf_parser_type=vcf_parser_type,
            **kwargs)
        return source_test_utils.read_from_source(source)

    def _create_temp_file_and_read_records(
            self,
            lines,
            representative_header_lines=None,
            vcf_parser_type=VcfParserType.PYVCF):
        """Writes ``lines`` to a temp ``.vcf`` file and reads all records back."""
        with TempDir() as tempdir:
            temp_path = tempdir.create_temp_file(suffix='.vcf', lines=lines)
            return self._read_records(
                temp_path,
                representative_header_lines=representative_header_lines,
                vcf_parser_type=vcf_parser_type)

    def _assert_variants_equal(self, actual, expected):
        """Asserts both variant collections hold the same items, ignoring order."""
        expected_sorted = sorted(expected)
        actual_sorted = sorted(actual)
        self.assertEqual(expected_sorted, actual_sorted)

    def _get_invalid_file_contents(self):
        """Gets sample invalid files contents.

        Returns:
          A `tuple` where the first element is contents that are invalid because
          of record errors and the second element is contents that are invalid
          because of header errors.
        """
        malformed_vcf_records = [
            # Malformed record.
            [
                '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample\n',
                '1    1  '
            ],
            # Depending on whether pyvcf uses cython this case fails, this is a
            # known problem: https://github.com/apache/beam/pull/4221
            # Missing "GT:GQ" format, but GQ is provided.
            #[
            #    '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSample\n',
            #    '19\t123\trs12345\tT\tC\t50\tq10\tAF=0.2;NS=2\tGT\t1|0:48'
            #],
            # GT is not an integer.
            [
                '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample\n',
                '19	123	rs12345	T	C	50	q10	AF=0.2;NS=2	GT	A|0'
            ],
            # POS should be an integer.
            [
                '##FILTER=<ID=PASS,Description="All filters passed">\n',
                '##FILTER=<ID=q10,Description="Quality is less than 10.">\n',
                '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample\n',
                '19	abc	rs12345	T	C	9	q10	AF=0.2;NS=2	GT:GQ	1|0:48\n',
            ]
        ]
        malformed_header_lines = [
            # Malformed FILTER.
            [
                '##FILTER=<ID=PASS,Description="All filters passed">\n',
                '##FILTER=<ID=LowQual,Descri\n',
                '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample\n',
                '19	123	rs12345	T	C	50	q10	AF=0.2;NS=2	GT:GQ	1|0:48',
            ],
            # Invalid Number value for INFO.
            [
                '##INFO=<ID=G,Number=U,Type=String,Description="InvalidNumber">\n',
                '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample\n',
                '19	123	rs12345	T	C	50	q10	AF=0.2;NS=2	GT:GQ	1|0:48\n',
            ]
        ]

        return (malformed_vcf_records, malformed_header_lines)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_read_single_file_large(self):
        """Reads each sample file and checks its expected record count."""
        # Renamed local from the original typo `test_data_conifgs`.
        test_data_configs = [
            {
                'file': 'valid-4.0.vcf',
                'num_records': 5
            },
            {
                'file': 'valid-4.0.vcf.gz',
                'num_records': 5
            },
            {
                'file': 'valid-4.0.vcf.bz2',
                'num_records': 5
            },
            {
                'file': 'valid-4.1-large.vcf',
                'num_records': 9882
            },
            {
                'file': 'valid-4.2.vcf',
                'num_records': 13
            },
        ]
        for config in test_data_configs:
            read_data = self._read_records(
                testdata_util.get_full_file_path(config['file']))
            self.assertEqual(config['num_records'], len(read_data))

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_read_file_pattern_large(self):
        """Glob patterns over plain and gzip files each yield 9900 records."""
        for pattern in ('valid-*.vcf', 'valid-*.vcf.gz'):
            records = self._read_records(
                os.path.join(testdata_util.get_full_dir(), pattern))
            self.assertEqual(9900, len(records))

    def test_single_file_no_records(self):
        """Whitespace-only file contents yield no records, with or without headers."""
        empty_contents = [[''], [' '], ['', ' ', '\n'], ['\n', '\r\n', '\n']]
        for content in empty_contents:
            self.assertEqual(
                [], self._create_temp_file_and_read_records(content))
            self.assertEqual(
                [],
                self._create_temp_file_and_read_records(
                    content, _SAMPLE_HEADER_LINES))

    def test_single_file_verify_details(self):
        """Parsed variants from one file match the expected Variant objects."""
        variant_1, vcf_line_1 = _get_sample_variant_1()
        read_data = self._create_temp_file_and_read_records(
            _SAMPLE_HEADER_LINES + [vcf_line_1])
        self.assertEqual(1, len(read_data))
        self.assertEqual(variant_1, read_data[0])

        variant_2, vcf_line_2 = _get_sample_variant_2()
        variant_3, vcf_line_3 = _get_sample_variant_3()
        read_data = self._create_temp_file_and_read_records(
            _SAMPLE_HEADER_LINES + [vcf_line_1, vcf_line_2, vcf_line_3])
        self.assertEqual(3, len(read_data))
        self._assert_variants_equal([variant_1, variant_2, variant_3],
                                    read_data)

    @unittest.skipIf(NUCLEUS_IMPORT_MISSING, 'Nucleus is not imported')
    def test_single_file_verify_details_nucleus(self):
        """Same as test_single_file_verify_details but via the Nucleus parser."""
        variant_1, vcf_line_1 = _get_sample_variant_1(is_for_nucleus=True)
        read_data = self._create_temp_file_and_read_records(
            _NUCLEUS_HEADER_LINES + [vcf_line_1],
            vcf_parser_type=VcfParserType.NUCLEUS)
        self.assertEqual(1, len(read_data))
        self.assertEqual(variant_1, read_data[0])

        variant_2, vcf_line_2 = _get_sample_variant_2(is_for_nucleus=True)
        variant_3, vcf_line_3 = _get_sample_variant_3(is_for_nucleus=True)
        read_data = self._create_temp_file_and_read_records(
            _NUCLEUS_HEADER_LINES + [vcf_line_1, vcf_line_2, vcf_line_3],
            vcf_parser_type=VcfParserType.NUCLEUS)
        self.assertEqual(3, len(read_data))
        self._assert_variants_equal([variant_1, variant_2, variant_3],
                                    read_data)

    def test_file_pattern_verify_details(self):
        """A glob pattern over two temp files yields the union of their variants."""
        variant_1, vcf_line_1 = _get_sample_variant_1()
        variant_2, vcf_line_2 = _get_sample_variant_2()
        variant_3, vcf_line_3 = _get_sample_variant_3()
        with TempDir() as tempdir:
            self._create_temp_vcf_file(_SAMPLE_HEADER_LINES + [vcf_line_1],
                                       tempdir)
            self._create_temp_vcf_file(
                (_SAMPLE_HEADER_LINES + [vcf_line_2, vcf_line_3]), tempdir)
            read_data = self._read_records(
                os.path.join(tempdir.get_path(), '*.vcf'))
            self.assertEqual(3, len(read_data))
            self._assert_variants_equal([variant_1, variant_2, variant_3],
                                        read_data)

    @unittest.skipIf(NUCLEUS_IMPORT_MISSING, 'Nucleus is not imported')
    def test_file_pattern_verify_details_nucleus(self):
        """Same as test_file_pattern_verify_details but via the Nucleus parser."""
        variant_1, vcf_line_1 = _get_sample_variant_1(is_for_nucleus=True)
        variant_2, vcf_line_2 = _get_sample_variant_2(is_for_nucleus=True)
        variant_3, vcf_line_3 = _get_sample_variant_3(is_for_nucleus=True)
        with TempDir() as tempdir:
            self._create_temp_vcf_file(_NUCLEUS_HEADER_LINES + [vcf_line_1],
                                       tempdir)
            self._create_temp_vcf_file(
                (_NUCLEUS_HEADER_LINES + [vcf_line_2, vcf_line_3]), tempdir)
            read_data = self._read_records(
                os.path.join(tempdir.get_path(), '*.vcf'),
                vcf_parser_type=VcfParserType.NUCLEUS)
            self.assertEqual(3, len(read_data))
            self._assert_variants_equal([variant_1, variant_2, variant_3],
                                        read_data)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_read_after_splitting(self):
        """Splitting the source and reading every split yields all records."""
        file_name = testdata_util.get_full_file_path('valid-4.1-large.vcf')
        source = VcfSource(file_name)
        # list(...) instead of a copying comprehension (ruff PERF402).
        splits = list(source.split(desired_bundle_size=500))
        self.assertGreater(len(splits), 1)
        sources_info = [(split.source, split.start_position,
                         split.stop_position) for split in splits]
        self.assertGreater(len(sources_info), 1)
        split_records = []
        for source_info in sources_info:
            split_records.extend(
                source_test_utils.read_from_source(*source_info))
        self.assertEqual(9882, len(split_records))

    def test_invalid_file(self):
        """Reading malformed records or malformed headers raises ValueError."""
        invalid_file_contents = self._get_invalid_file_contents()

        for content in chain(*invalid_file_contents):
            with TempDir() as tempdir, self.assertRaises(ValueError):
                self._read_records(self._create_temp_vcf_file(
                    content, tempdir))
                # Reached only if _read_records did NOT raise; the resulting
                # AssertionError escapes assertRaises(ValueError) and fails
                # the test with this message.
                self.fail('Invalid VCF file must throw an exception')
        # Try with multiple files (any one of them will throw an exception).
        with TempDir() as tempdir, self.assertRaises(ValueError):
            for content in chain(*invalid_file_contents):
                self._create_temp_vcf_file(content, tempdir)
                self._read_records(os.path.join(tempdir.get_path(), '*.vcf'))

    def test_allow_malformed_records(self):
        """allow_malformed_records tolerates record errors but not header errors."""
        invalid_records, invalid_headers = self._get_invalid_file_contents()

        # Invalid records should not raise errors
        for content in invalid_records:
            with TempDir() as tempdir:
                self._read_records(self._create_temp_vcf_file(
                    content, tempdir),
                                   allow_malformed_records=True)
        # Invalid headers should still raise errors
        for content in invalid_headers:
            with TempDir() as tempdir, self.assertRaises(ValueError):
                self._read_records(self._create_temp_vcf_file(
                    content, tempdir),
                                   allow_malformed_records=True)

    def test_no_samples(self):
        """A record without sample columns parses into a Variant with no calls."""
        header_line = '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO\n'
        record_line = '19	123	.	G	A	.	PASS	AF=0.2'
        # POS is 1-based in VCF; Variant start/end are 0-based half-open.
        expected_variant = Variant(reference_name='19',
                                   start=122,
                                   end=123,
                                   reference_bases='G',
                                   alternate_bases=['A'],
                                   filters=['PASS'],
                                   info={'AF': [0.2]})
        read_data = self._create_temp_file_and_read_records(
            _SAMPLE_HEADER_LINES[:-1] + [header_line, record_line])
        self.assertEqual(1, len(read_data))
        self.assertEqual(expected_variant, read_data[0])

    def test_no_info(self):
        """Missing INFO/ID/REF/ALT fields parse into empty Variant attributes."""
        record_line = 'chr19	123	.	.	.	.	.	.	GT	.	.'
        expected_variant = Variant(reference_name='chr19', start=122, end=123)
        expected_variant.calls.append(
            VariantCall(name='Sample1',
                        genotype=[vcfio.MISSING_GENOTYPE_VALUE]))
        expected_variant.calls.append(
            VariantCall(name='Sample2',
                        genotype=[vcfio.MISSING_GENOTYPE_VALUE]))
        read_data = self._create_temp_file_and_read_records(
            _SAMPLE_HEADER_LINES + [record_line])
        self.assertEqual(1, len(read_data))
        self.assertEqual(expected_variant, read_data[0])

    def test_info_numbers_and_types(self):
        """INFO fields with Number=A/G/R/0/. and various Types parse correctly."""
        info_headers = [
            '##INFO=<ID=HA,Number=A,Type=String,Description="StringInfo_A">\n',
            '##INFO=<ID=HG,Number=G,Type=Integer,Description="IntInfo_G">\n',
            '##INFO=<ID=HR,Number=R,Type=Character,Description="ChrInfo_R">\n',
            '##INFO=<ID=HF,Number=0,Type=Flag,Description="FlagInfo">\n',
            '##INFO=<ID=HU,Number=.,Type=Float,Description="FloatInfo_variable">\n'
        ]
        record_lines = [
            '19	2	.	A	T,C	.	.	HA=a1,a2;HG=1,2,3;HR=a,b,c;HF;HU=0.1	GT	1/0	0/1\n',
            '19	124	.	A	T	.	.	HG=3,4,5;HR=d,e;HU=1.1,1.2	GT	0/0	0/1'
        ]
        variant_1 = Variant(reference_name='19',
                            start=1,
                            end=2,
                            reference_bases='A',
                            alternate_bases=['T', 'C'],
                            info={
                                'HA': ['a1', 'a2'],
                                'HG': [1, 2, 3],
                                'HR': ['a', 'b', 'c'],
                                # Flag fields parse to True, not a list.
                                'HF': True,
                                'HU': [0.1]
                            })
        variant_1.calls.append(VariantCall(name='Sample1', genotype=[1, 0]))
        variant_1.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
        variant_2 = Variant(reference_name='19',
                            start=123,
                            end=124,
                            reference_bases='A',
                            alternate_bases=['T'],
                            info={
                                'HG': [3, 4, 5],
                                'HR': ['d', 'e'],
                                'HU': [1.1, 1.2]
                            })
        variant_2.calls.append(VariantCall(name='Sample1', genotype=[0, 0]))
        variant_2.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
        read_data = self._create_temp_file_and_read_records(
            info_headers + _SAMPLE_HEADER_LINES[1:] + record_lines)
        self.assertEqual(2, len(read_data))
        self._assert_variants_equal([variant_1, variant_2], read_data)

    def test_use_of_representative_header(self):
        """Representative header overrides a file header with a wrong Type.

        Info field `HU` is defined as Float in the file header while the data
        is String, which makes the parser fail. Parsing must succeed when a
        representative header with a String definition for `HU` is supplied.
        """
        file_content = [
            '##INFO=<ID=HU,Number=.,Type=Float,Description="Info">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\r\n',
            '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample1	Sample2\r\n',
            '19	2	.	A	T	.	.	HU=a,b	GT	0/0	0/1\n',
        ]
        representative_header_lines = [
            '##INFO=<ID=HU,Number=.,Type=String,Description="Info">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\r\n',
        ]
        variant = Variant(reference_name='19',
                          start=1,
                          end=2,
                          reference_bases='A',
                          alternate_bases=['T'],
                          info={'HU': ['a', 'b']})
        variant.calls.append(VariantCall(name='Sample1', genotype=[0, 0]))
        variant.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))

        # `file_headers` is used: parsing fails with the file's own header.
        # (Dropped the unused `read_data =` assignment the original had here.)
        with self.assertRaises(ValueError):
            self._create_temp_file_and_read_records(file_content)

        # `representative_header` is used.
        read_data = self._create_temp_file_and_read_records(
            file_content, representative_header_lines)
        self.assertEqual(1, len(read_data))
        self._assert_variants_equal([variant], read_data)

    def test_use_of_representative_header_two_files(self):
        # Info field `HU` is defined as Float in file header while data is String.
        # This results in parser failure. We test if parser completes successfully
        # when a representative headers with String definition for field `HU` is
        # given.
        file_content_1 = [
            '##INFO=<ID=HU,Number=.,Type=Float,Description="Info">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\r\n',
            '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample1\r\n',
            # NOTE(review): this record line is delimited with runs of spaces
            # rather than tabs, unlike every other record in this file — looks
            # like tabs were mangled to spaces at some point; verify against
            # the VCF spec (fields must be tab-delimited).
            '9     2       .       A       T       .       .       HU=a,b  GT 0/0'
        ]
        file_content_2 = [
            '##INFO=<ID=HU,Number=.,Type=Float,Description="Info">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\r\n',
            '#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO	FORMAT	Sample2\r\n',
            '19	2	.	A	T	.	.	HU=a,b	GT	0/1\n',
        ]
        representative_header_lines = [
            '##INFO=<ID=HU,Number=.,Type=String,Description="Info">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\r\n',
        ]

        variant_1 = Variant(reference_name='9',
                            start=1,
                            end=2,
                            reference_bases='A',
                            alternate_bases=['T'],
                            info={'HU': ['a', 'b']})
        variant_1.calls.append(VariantCall(name='Sample1', genotype=[0, 0]))

        variant_2 = Variant(reference_name='19',
                            start=1,
                            end=2,
                            reference_bases='A',
                            alternate_bases=['T'],
                            info={'HU': ['a', 'b']})
        variant_2.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))

        read_data_1 = self._create_temp_file_and_read_records(
            file_content_1, representative_header_lines)
        self.assertEqual(1, len(read_data_1))
        self._assert_variants_equal([variant_1], read_data_1)

        read_data_2 = self._create_temp_file_and_read_records(
            file_content_2, representative_header_lines)
        self.assertEqual(1, len(read_data_2))
        self._assert_variants_equal([variant_2], read_data_2)

    def test_end_info_key(self):
        """The END info key overrides the variant's end; absent END uses POS+1."""
        end_info_header_line = (
            '##INFO=<ID=END,Number=1,Type=Integer,Description="End of record.">\n'
        )
        record_lines = [
            '19	123	.	A	.	.	.	END=1111	GT	1/0	0/1\n',
            '19	123	.	A	.	.	.	.	GT	0/1	1/1\n'
        ]
        variant_1 = Variant(reference_name='19',
                            start=122,
                            end=1111,
                            reference_bases='A')
        variant_1.calls.append(VariantCall(name='Sample1', genotype=[1, 0]))
        variant_1.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
        variant_2 = Variant(reference_name='19',
                            start=122,
                            end=123,
                            reference_bases='A')
        variant_2.calls.append(VariantCall(name='Sample1', genotype=[0, 1]))
        variant_2.calls.append(VariantCall(name='Sample2', genotype=[1, 1]))
        read_data = self._create_temp_file_and_read_records(
            [end_info_header_line] + _SAMPLE_HEADER_LINES[1:] + record_lines)
        self.assertEqual(2, len(read_data))
        self._assert_variants_equal([variant_1, variant_2], read_data)

    def test_end_info_key_unknown_number(self):
        """END declared with Number=. still sets the variant end when single-valued."""
        end_info_header_line = (
            '##INFO=<ID=END,Number=.,Type=Integer,Description="End of record.">\n'
        )
        record_lines = ['19	123	.	A	.	.	.	END=1111	GT	1/0	0/1\n']
        variant_1 = Variant(reference_name='19',
                            start=122,
                            end=1111,
                            reference_bases='A')
        variant_1.calls.append(VariantCall(name='Sample1', genotype=[1, 0]))
        variant_1.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
        read_data = self._create_temp_file_and_read_records(
            [end_info_header_line] + _SAMPLE_HEADER_LINES[1:] + record_lines)
        self.assertEqual(1, len(read_data))
        self._assert_variants_equal([variant_1], read_data)

    def test_end_info_key_unknown_number_invalid(self):
        """Multi-valued or non-integer END values raise ValueError."""
        end_info_header_line = (
            '##INFO=<ID=END,Number=.,Type=Integer,Description="End of record.">\n'
        )
        # END should only have one field.
        with self.assertRaises(ValueError):
            self._create_temp_file_and_read_records(
                [end_info_header_line] + _SAMPLE_HEADER_LINES[1:] +
                ['19	124	.	A	.	.	.	END=150,160	GT	1/0	0/1\n'])
        # END should be an integer.
        with self.assertRaises(ValueError):
            self._create_temp_file_and_read_records(
                [end_info_header_line] + _SAMPLE_HEADER_LINES[1:] +
                ['19	124	.	A	.	.	.	END=150.1	GT	1/0	0/1\n'])

    def test_custom_phaseset(self):
        """The PS format field is surfaced as the call's phaseset; '.' means none."""
        phaseset_header_line = (
            '##FORMAT=<ID=PS,Number=1,Type=Integer,Description="Phaseset">\n')
        record_lines = [
            '19	123	.	A	T	.	.	.	GT:PS	1|0:1111	0/1:.\n',
            '19	121	.	A	T	.	.	.	GT:PS	1|0:2222	0/1:2222\n'
        ]
        variant_1 = Variant(reference_name='19',
                            start=122,
                            end=123,
                            reference_bases='A',
                            alternate_bases=['T'])
        variant_1.calls.append(
            VariantCall(name='Sample1', genotype=[1, 0], phaseset='1111'))
        variant_1.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
        variant_2 = Variant(reference_name='19',
                            start=120,
                            end=121,
                            reference_bases='A',
                            alternate_bases=['T'])
        variant_2.calls.append(
            VariantCall(name='Sample1', genotype=[1, 0], phaseset='2222'))
        variant_2.calls.append(
            VariantCall(name='Sample2', genotype=[0, 1], phaseset='2222'))
        read_data = self._create_temp_file_and_read_records(
            [phaseset_header_line] + _SAMPLE_HEADER_LINES[1:] + record_lines)
        self.assertEqual(2, len(read_data))
        self._assert_variants_equal([variant_1, variant_2], read_data)

    def test_format_numbers(self):
        """FORMAT fields with Number=./1/2/A/G parse into per-call info values."""
        format_headers = [
            '##FORMAT=<ID=FU,Number=.,Type=String,Description="Format_variable">\n',
            '##FORMAT=<ID=F1,Number=1,Type=Integer,Description="Format_1">\n',
            '##FORMAT=<ID=F2,Number=2,Type=Character,Description="Format_2">\n',
            '##FORMAT=<ID=AO,Number=A,Type=Integer,Description="Format_3">\n',
            '##FORMAT=<ID=AD,Number=G,Type=Integer,Description="Format_4">\n',
        ]

        record_lines = [('19	2	.	A	T,C	.	.	.	'
                         'GT:FU:F1:F2:AO:AD	1/0:a1:3:a,b:1:3,4	'
                         '0/1:a2,a3:4:b,c:1,2:3')]
        expected_variant = Variant(reference_name='19',
                                   start=1,
                                   end=2,
                                   reference_bases='A',
                                   alternate_bases=['T', 'C'])
        expected_variant.calls.append(
            VariantCall(name='Sample1',
                        genotype=[1, 0],
                        info={
                            'FU': ['a1'],
                            # Number=1 fields parse to a scalar, not a list.
                            'F1': 3,
                            'F2': ['a', 'b'],
                            'AO': [1],
                            'AD': [3, 4]
                        }))
        expected_variant.calls.append(
            VariantCall(name='Sample2',
                        genotype=[0, 1],
                        info={
                            'FU': ['a2', 'a3'],
                            'F1': 4,
                            'F2': ['b', 'c'],
                            'AO': [1, 2],
                            'AD': [3]
                        }))
        read_data = self._create_temp_file_and_read_records(
            format_headers + _SAMPLE_HEADER_LINES[1:] + record_lines)
        self.assertEqual(1, len(read_data))
        self.assertEqual(expected_variant, read_data[0])

    def _assert_pipeline_read_files_record_count_equal(self,
                                                       input_pattern,
                                                       expected_count,
                                                       use_read_all=False):
        """Helper method for verifying total records read.

        Args:
          input_pattern (str): Input file pattern to read.
          expected_count (int): Expected number of records that was read.
          use_read_all (bool): Whether to use the scalable ReadAllFromVcf
            transform instead of ReadFromVcf.
        """
        pipeline = TestPipeline()
        if use_read_all:
            pcoll = (pipeline
                     | 'Create' >> beam.Create([input_pattern])
                     | 'Read' >> ReadAllFromVcf())
        else:
            pcoll = pipeline | 'Read' >> ReadFromVcf(input_pattern)
        assert_that(pcoll, asserts.count_equals_to(expected_count))
        pipeline.run()

    def test_pipeline_read_single_file(self):
        """ReadFromVcf on one temp file yields one record per data line."""
        with TempDir() as tempdir:
            temp_file = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            self._assert_pipeline_read_files_record_count_equal(
                temp_file, len(_SAMPLE_TEXT_LINES))

    def test_pipeline_read_all_single_file(self):
        """ReadAllFromVcf on one temp file yields one record per data line."""
        with TempDir() as tempdir:
            file_name = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            self._assert_pipeline_read_files_record_count_equal(
                file_name, len(_SAMPLE_TEXT_LINES), use_read_all=True)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_pipeline_read_single_file_large(self):
        """ReadFromVcf on the large sample file reads all 9882 records."""
        self._assert_pipeline_read_files_record_count_equal(
            testdata_util.get_full_file_path('valid-4.1-large.vcf'), 9882)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_pipeline_read_all_single_file_large(self):
        """ReadAllFromVcf on the large sample file reads all 9882 records."""
        self._assert_pipeline_read_files_record_count_equal(
            testdata_util.get_full_file_path('valid-4.1-large.vcf'),
            9882,
            use_read_all=True)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_pipeline_read_file_pattern_large(self):
        """ReadFromVcf over the valid-*.vcf pattern reads all 9900 records."""
        self._assert_pipeline_read_files_record_count_equal(
            os.path.join(testdata_util.get_full_dir(), 'valid-*.vcf'), 9900)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_pipeline_read_all_file_pattern_large(self):
        """ReadAllFromVcf over the valid-*.vcf pattern reads all 9900 records."""
        self._assert_pipeline_read_files_record_count_equal(os.path.join(
            testdata_util.get_full_dir(), 'valid-*.vcf'),
                                                            9900,
                                                            use_read_all=True)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_pipeline_read_all_gzip_large(self):
        """ReadAllFromVcf over the gzipped pattern reads all 9900 records."""
        self._assert_pipeline_read_files_record_count_equal(os.path.join(
            testdata_util.get_full_dir(), 'valid-*.vcf.gz'),
                                                            9900,
                                                            use_read_all=True)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_pipeline_read_all_multiple_files_large(self):
        """ReadAllFromVcf over an explicit file list reads all 9900 records."""
        pipeline = TestPipeline()
        pcoll = (pipeline
                 | 'Create' >> beam.Create([
                     testdata_util.get_full_file_path('valid-4.0.vcf'),
                     testdata_util.get_full_file_path('valid-4.1-large.vcf'),
                     testdata_util.get_full_file_path('valid-4.2.vcf')
                 ])
                 | 'Read' >> ReadAllFromVcf())
        assert_that(pcoll, asserts.count_equals_to(9900))
        pipeline.run()

    def test_pipeline_read_all_gzip(self):
        """ReadAllFromVcf reads gzip-compressed files transparently."""
        with TempDir() as tempdir:
            file_name_1 = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES,
                tempdir,
                compression_type=CompressionTypes.GZIP)
            file_name_2 = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES,
                tempdir,
                compression_type=CompressionTypes.GZIP)
            pipeline = TestPipeline()
            pcoll = (pipeline
                     | 'Create' >> beam.Create([file_name_1, file_name_2])
                     | 'Read' >> ReadAllFromVcf())
            assert_that(pcoll,
                        asserts.count_equals_to(2 * len(_SAMPLE_TEXT_LINES)))
            pipeline.run()

    def test_pipeline_read_all_bzip2(self):
        """ReadAllFromVcf reads bzip2-compressed files transparently."""
        with TempDir() as tempdir:
            file_name_1 = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES,
                tempdir,
                compression_type=CompressionTypes.BZIP2)
            file_name_2 = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES,
                tempdir,
                compression_type=CompressionTypes.BZIP2)
            pipeline = TestPipeline()
            pcoll = (pipeline
                     | 'Create' >> beam.Create([file_name_1, file_name_2])
                     | 'Read' >> ReadAllFromVcf())
            assert_that(pcoll,
                        asserts.count_equals_to(2 * len(_SAMPLE_TEXT_LINES)))
            pipeline.run()

    def test_pipeline_read_all_multiple_files(self):
        """ReadAllFromVcf over two uncompressed temp files reads both."""
        with TempDir() as tempdir:
            file_name_1 = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            file_name_2 = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            pipeline = TestPipeline()
            pcoll = (pipeline
                     | 'Create' >> beam.Create([file_name_1, file_name_2])
                     | 'Read' >> ReadAllFromVcf())
            assert_that(pcoll,
                        asserts.count_equals_to(2 * len(_SAMPLE_TEXT_LINES)))
            pipeline.run()

    def test_read_reentrant_without_splitting(self):
        """Reading the unsplit source twice yields the same records."""
        with TempDir() as tempdir:
            file_name = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            source = VcfSource(file_name)
            source_test_utils.assert_reentrant_reads_succeed(
                (source, None, None))

    def test_read_reentrant_after_splitting(self):
        """Reading a single split twice yields the same records."""
        with TempDir() as tempdir:
            file_name = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            source = VcfSource(file_name)
            # list(...) instead of a copying comprehension; use a unittest
            # assertion instead of a bare `assert` (stripped under `-O`).
            splits = list(source.split(desired_bundle_size=100000))
            self.assertEqual(1, len(splits))
            source_test_utils.assert_reentrant_reads_succeed(
                (splits[0].source, splits[0].start_position,
                 splits[0].stop_position))

    def test_dynamic_work_rebalancing(self):
        """The source supports splitting at every fraction of a single split."""
        with TempDir() as tempdir:
            file_name = self._create_temp_vcf_file(
                _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
            source = VcfSource(file_name)
            # list(...) instead of a copying comprehension; use a unittest
            # assertion instead of a bare `assert` (stripped under `-O`).
            splits = list(source.split(desired_bundle_size=100000))
            self.assertEqual(1, len(splits))
            source_test_utils.assert_split_at_fraction_exhaustive(
                splits[0].source, splits[0].start_position,
                splits[0].stop_position)
# Example #5  (scraping artifact — this marker and the "0" below are not code)
# 0
class VcfHeaderSourceTest(unittest.TestCase):
  """Tests for reading VCF header metadata via ``VcfHeaderSource``."""

  # TODO(msaul): Replace get_full_dir() with function from utils.
  # Distribution should skip tests that need VCF files due to large size
  VCF_FILE_DIR_MISSING = not os.path.exists(testdata_util.get_full_dir())

  def setUp(self):
    # Default input; individual tests may override ``self.lines`` before
    # calling ``_create_file_and_read_headers``.
    self.lines = testdata_util.get_sample_vcf_header_lines()

  def _create_file_and_read_headers(self):
    """Writes ``self.lines`` to a temp .vcf file and returns its header."""
    with temp_dir.TempDir() as tempdir:
      filename = tempdir.create_temp_file(suffix='.vcf', lines=self.lines)
      headers = source_test_utils.read_from_source(
          VcfHeaderSource(filename))
      return headers[0]

  def test_vcf_header_eq(self):
    """Headers built from identical lines compare equal."""
    header_1 = _get_vcf_header_from_lines(self.lines)
    header_2 = _get_vcf_header_from_lines(self.lines)
    self.assertEqual(header_1, header_2)

  def test_read_file_headers(self):
    """Reading a full file (headers + records) returns only the headers."""
    headers = self.lines
    self.lines = testdata_util.get_sample_vcf_file_lines()
    header = self._create_file_and_read_headers()
    self.assertEqual(header, _get_vcf_header_from_lines(headers))

  def test_malformed_headers(self):
    """Malformed header lines must raise ``ValueError``."""
    # TODO(tneymanov): Add more tests.
    malformed_header_lines = [
        # Malformed FILTER.
        [
            '##FILTER=<ID=PASS,Description="All filters passed">\n',
            '##FILTER=<ID=LowQual,Descri\n',
            '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSample\n',
            '19\t123\trs12345\tT\tC\t50\tq10\tAF=0.2;NS=2\tGT:GQ\t1|0:48'
        ]
    ]

    for content in malformed_header_lines:
      self.lines = content
      with self.assertRaises(ValueError):
        self._create_file_and_read_headers()

  def test_all_fields(self):
    """Contig, ALT, FILTER, INFO and FORMAT entries are all parsed."""
    # The column header below uses ``\t`` escapes for consistency with the
    # other header strings in this class (same runtime bytes as literal tabs).
    self.lines = [
        '##contig=<ID=M,length=16,assembly=B37,md5=c6,species="Homosapiens">\n',
        '##contig=<ID=P,length=16,assembly=B37,md5=c6,species="Homosapiens">\n',
        '\n',
        '##ALT=<ID=CGA_CNVWIN,Description="Copy number analysis window">\n',
        '##ALT=<ID=INS:ME:MER,Description="Insertion of MER element">\n',
        '##FILTER=<ID=MPCBT,Description="Mate pair count below 10">\n',
        '##INFO=<ID=CGA_MIRB,Number=.,Type=String,Description="miRBaseId">\n',
        '##FORMAT=<ID=FT,Number=1,Type=String,Description="Genotype filter">\n',
        '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tGS000016676-ASM\n',
    ]
    header = self._create_file_and_read_headers()
    self.assertCountEqual(list(header.contigs.keys()), ['M', 'P'])
    self.assertCountEqual(
        list(header.alts.keys()), ['CGA_CNVWIN', 'INS:ME:MER'])
    self.assertCountEqual(list(header.filters.keys()), ['MPCBT'])
    self.assertCountEqual(list(header.infos.keys()), ['CGA_MIRB'])
    self.assertCountEqual(list(header.formats.keys()), ['FT'])

  def test_empty_header_raises_error(self):
    """A file with records but no header lines must fail to parse."""
    self.lines = testdata_util.get_sample_vcf_record_lines()
    with self.assertRaises(ValueError):
      self._create_file_and_read_headers()

  def test_read_file_pattern(self):
    """Reading a glob pattern yields one header per matching file."""
    with temp_dir.TempDir() as tempdir:
      headers_1 = [self.lines[1], self.lines[-1]]
      headers_2 = [self.lines[2], self.lines[3], self.lines[-1]]
      headers_3 = [self.lines[4], self.lines[-1]]
      file_name_1 = tempdir.create_temp_file(suffix='.vcf', lines=headers_1)
      file_name_2 = tempdir.create_temp_file(suffix='.vcf', lines=headers_2)
      file_name_3 = tempdir.create_temp_file(suffix='.vcf', lines=headers_3)

      actual = source_test_utils.read_from_source(VcfHeaderSource(
          os.path.join(tempdir.get_path(), '*.vcf')))

      expected = [_get_vcf_header_from_lines(h, file_name=file_name)
                  for h, file_name in [(headers_1, file_name_1),
                                       (headers_2, file_name_2),
                                       (headers_3, file_name_3)]]

      asserts.header_vars_equal(expected)(actual)

  @unittest.skipIf(VCF_FILE_DIR_MISSING, 'VCF test file directory is missing')
  def test_read_single_file_large(self):
    """Checks INFO/FORMAT counts for the large checked-in test files."""
    # Fixed typo: the local was previously named ``test_data_conifgs``.
    test_data_configs = [
        {'file': 'valid-4.0.vcf', 'num_infos': 6, 'num_formats': 4},
        {'file': 'valid-4.0.vcf.gz', 'num_infos': 6, 'num_formats': 4},
        {'file': 'valid-4.0.vcf.bz2', 'num_infos': 6, 'num_formats': 4},
        {'file': 'valid-4.1-large.vcf', 'num_infos': 21, 'num_formats': 33},
        {'file': 'valid-4.2.vcf', 'num_infos': 8, 'num_formats': 5},
    ]
    for config in test_data_configs:
      read_data = source_test_utils.read_from_source(VcfHeaderSource(
          testdata_util.get_full_file_path(config['file'])))
      self.assertEqual(config['num_infos'], len(read_data[0].infos))
      self.assertEqual(config['num_formats'], len(read_data[0].formats))

  def test_pipeline_read_file_headers(self):
    """``ReadVcfHeaders`` on a single file yields that file's header."""
    headers = self.lines
    self.lines = testdata_util.get_sample_vcf_file_lines()

    with temp_dir.TempDir() as tempdir:
      filename = tempdir.create_temp_file(suffix='.vcf', lines=self.lines)

      pipeline = TestPipeline()
      pcoll = pipeline | 'ReadHeaders' >> ReadVcfHeaders(filename)

      assert_that(pcoll, equal_to([_get_vcf_header_from_lines(headers)]))
      pipeline.run()

  def test_pipeline_read_all_file_headers(self):
    """``ReadAllVcfHeaders`` reads file names from the input PCollection."""
    headers = self.lines
    self.lines = testdata_util.get_sample_vcf_file_lines()

    with temp_dir.TempDir() as tempdir:
      filename = tempdir.create_temp_file(suffix='.vcf', lines=self.lines)

      pipeline = TestPipeline()
      pcoll = (pipeline
               | 'Create' >> beam.Create([filename])
               | 'ReadHeaders' >> ReadAllVcfHeaders())

      assert_that(pcoll, equal_to([_get_vcf_header_from_lines(headers)]))
      pipeline.run()

  def test_pipeline_read_file_pattern(self):
    """``ReadVcfHeaders`` on a glob pattern yields one header per file."""
    with temp_dir.TempDir() as tempdir:
      headers_1 = [self.lines[1], self.lines[-1]]
      headers_2 = [self.lines[2], self.lines[3], self.lines[-1]]
      headers_3 = [self.lines[4], self.lines[-1]]

      file_name_1 = tempdir.create_temp_file(suffix='.vcf', lines=headers_1)
      file_name_2 = tempdir.create_temp_file(suffix='.vcf', lines=headers_2)
      file_name_3 = tempdir.create_temp_file(suffix='.vcf', lines=headers_3)

      pipeline = TestPipeline()
      pcoll = pipeline | 'ReadHeaders' >> ReadVcfHeaders(
          os.path.join(tempdir.get_path(), '*.vcf'))

      expected = [_get_vcf_header_from_lines(h, file_name=file_name)
                  for h, file_name in [(headers_1, file_name_1),
                                       (headers_2, file_name_2),
                                       (headers_3, file_name_3)]]
      assert_that(pcoll, asserts.header_vars_equal(expected))
      pipeline.run()

  def test_pipeline_read_all_file_pattern(self):
    """``ReadAllVcfHeaders`` expands glob patterns from the PCollection."""
    with temp_dir.TempDir() as tempdir:
      headers_1 = [self.lines[1], self.lines[-1]]
      headers_2 = [self.lines[2], self.lines[3], self.lines[-1]]
      headers_3 = [self.lines[4], self.lines[-1]]

      file_name_1 = tempdir.create_temp_file(suffix='.vcf', lines=headers_1)
      file_name_2 = tempdir.create_temp_file(suffix='.vcf', lines=headers_2)
      file_name_3 = tempdir.create_temp_file(suffix='.vcf', lines=headers_3)

      pipeline = TestPipeline()
      pcoll = (pipeline
               | 'Create' >> beam.Create(
                   [os.path.join(tempdir.get_path(), '*.vcf')])
               | 'ReadHeaders' >> ReadAllVcfHeaders())

      expected = [_get_vcf_header_from_lines(h, file_name=file_name)
                  for h, file_name in [(headers_1, file_name_1),
                                       (headers_2, file_name_2),
                                       (headers_3, file_name_3)]]
      assert_that(pcoll, asserts.header_vars_equal(expected))
      pipeline.run()
# NOTE(review): removed stray scrape-artifact lines ("Exemple #6" / "0")
# that were not valid Python and broke the module between classes.
class VcfEstimateSourceTest(unittest.TestCase):
    VCF_FILE_DIR_MISSING = not os.path.exists(testdata_util.get_full_dir())

    def setUp(self):
        # Sample fixtures used by the tests below; individual tests may
        # override ``self.lines`` before reading estimates.
        self.headers = testdata_util.get_sample_vcf_header_lines()
        self.records = testdata_util.get_sample_vcf_record_lines()
        self.lines = testdata_util.get_sample_vcf_file_lines()

    def _create_file_and_read_estimates(self):
        """Writes ``self.lines`` to a temp .vcf file and returns its estimate."""
        with temp_dir.TempDir() as tempdir:
            vcf_path = tempdir.create_temp_file(suffix='.vcf',
                                                lines=self.lines)
            return source_test_utils.read_from_source(
                VcfEstimateSource(vcf_path))[0]

    def test_vcf_estimate_eq(self):
        """Estimates built from identical lines compare equal."""
        first = _get_estimate_from_lines(self.lines)
        second = _get_estimate_from_lines(self.lines)
        self.assertEqual(second, first)

    def test_read_file_estimates(self):
        """The estimate read from a file matches one built from its lines."""
        estimate = self._create_file_and_read_estimates()
        expected = _get_estimate_from_lines(self.lines, estimate.file_name)
        self.assertEqual(estimate, expected)

    def test_empty_header_raises_error(self):
        """A file with records but no header lines must fail to parse."""
        self.lines = testdata_util.get_sample_vcf_record_lines()
        self.assertRaises(ValueError, self._create_file_and_read_estimates)

    def test_read_file_pattern(self):
        """Reading a glob pattern yields one estimate per matching file."""
        with temp_dir.TempDir() as tempdir:
            file_contents = [
                self.headers[1:2] + self.headers[-1:] + self.records[:2],
                self.headers[2:4] + self.headers[-1:] + self.records[2:4],
                self.headers[4:5] + self.headers[-1:] + self.records[4:],
            ]
            file_names = [
                tempdir.create_temp_file(suffix='.vcf', lines=lines)
                for lines in file_contents
            ]

            actual = source_test_utils.read_from_source(
                VcfEstimateSource(os.path.join(tempdir.get_path(), '*.vcf')))

            expected = [
                _get_estimate_from_lines(lines, file_name=name)
                for lines, name in zip(file_contents, file_names)
            ]

            asserts.header_vars_equal(expected)(actual)

    @unittest.skipIf(VCF_FILE_DIR_MISSING,
                     'VCF test file directory is missing')
    def test_read_single_file_large(self):
        """Checks estimated variant counts and sizes for large test files."""
        # Fixed typo: the local was previously named ``test_data_conifgs``.
        test_data_configs = [
            {'file': 'valid-4.0.vcf', 'variant_count': 4, 'size': 1500},
            {'file': 'valid-4.0.vcf.gz', 'variant_count': 13, 'size': 1454},
            {'file': 'valid-4.0.vcf.bz2', 'variant_count': 14, 'size': 1562},
            {'file': 'valid-4.1-large.vcf',
             'variant_count': 14425,
             'size': 832396},
            {'file': 'valid-4.1-large.vcf.gz',
             'variant_count': 5498,
             'size': 313430},
            {'file': 'valid-4.2.vcf', 'variant_count': 10, 'size': 3195},
        ]
        for config in test_data_configs:
            read_data = source_test_utils.read_from_source(
                VcfEstimateSource(
                    testdata_util.get_full_file_path(config['file'])))
            # estimated_variant_count may be fractional; compare as int.
            self.assertEqual(config['variant_count'],
                             int(read_data[0].estimated_variant_count))
            self.assertEqual(config['size'], read_data[0].size_in_bytes)

    def test_pipeline_read_file_headers(self):
        """``GetEstimates`` on a single file yields that file's estimate."""
        with temp_dir.TempDir() as tempdir:
            filename = tempdir.create_temp_file(suffix='.vcf',
                                                lines=self.lines)

            pipeline = TestPipeline()
            pcoll = pipeline | 'GetEstimates' >> GetEstimates(filename)

            expected = [_get_estimate_from_lines(self.lines, filename)]
            assert_that(pcoll, equal_to(expected))
            pipeline.run()

    def test_pipeline_read_all_file_headers(self):
        """``GetAllEstimates`` reads file names from the input PCollection."""
        with temp_dir.TempDir() as tempdir:
            filename = tempdir.create_temp_file(suffix='.vcf',
                                                lines=self.lines)

            pipeline = TestPipeline()
            # The file names come from the Create step; GetAllEstimates is
            # constructed without arguments, consistent with its usage in
            # test_pipeline_read_all_file_pattern and with the parallel
            # ReadAllVcfHeaders() transform (previously the file name was
            # redundantly passed to the constructor as well).
            pcoll = (pipeline
                     | 'Create' >> beam.Create([filename])
                     | 'GetAllEstimates' >> GetAllEstimates())

            assert_that(
                pcoll,
                equal_to([_get_estimate_from_lines(self.lines, filename)]))
            pipeline.run()

    def test_pipeline_read_file_pattern(self):
        """``GetEstimates`` on a glob pattern yields one estimate per file."""
        with temp_dir.TempDir() as tempdir:
            file_contents = [
                self.headers[1:2] + self.headers[-1:] + self.records[:2],
                self.headers[2:4] + self.headers[-1:] + self.records[2:4],
                self.headers[4:5] + self.headers[-1:] + self.records[4:],
            ]
            file_names = [
                tempdir.create_temp_file(suffix='.vcf', lines=lines)
                for lines in file_contents
            ]

            pipeline = TestPipeline()
            pcoll = pipeline | 'ReadHeaders' >> GetEstimates(
                os.path.join(tempdir.get_path(), '*.vcf'))

            expected = [
                _get_estimate_from_lines(lines, file_name=name)
                for lines, name in zip(file_contents, file_names)
            ]
            assert_that(pcoll, asserts.header_vars_equal(expected))
            pipeline.run()

    def test_pipeline_read_all_file_pattern(self):
        """``GetAllEstimates`` expands glob patterns from the PCollection."""
        with temp_dir.TempDir() as tempdir:
            lines_1 = self.headers[1:2] + self.headers[-1:] + self.records[:2]
            lines_2 = self.headers[2:4] + self.headers[-1:] + self.records[2:4]
            lines_3 = self.headers[4:5] + self.headers[-1:] + self.records[4:]
            file_name_1 = tempdir.create_temp_file(suffix='.vcf',
                                                   lines=lines_1)
            file_name_2 = tempdir.create_temp_file(suffix='.vcf',
                                                   lines=lines_2)
            file_name_3 = tempdir.create_temp_file(suffix='.vcf',
                                                   lines=lines_3)

            pipeline = TestPipeline()
            # Removed a dead ``GetEstimates`` assignment that was immediately
            # overwritten below, yet still attached an extra (unasserted)
            # transform that would execute when the pipeline runs.
            pcoll = (pipeline
                     | 'Create' >> beam.Create(
                         [os.path.join(tempdir.get_path(), '*.vcf')])
                     | 'GetAllEstimates' >> GetAllEstimates())

            expected = [
                _get_estimate_from_lines(lines, file_name=file_name)
                for lines, file_name in [(lines_1, file_name_1),
                                         (lines_2, file_name_2),
                                         (lines_3, file_name_3)]
            ]
            assert_that(pcoll, asserts.header_vars_equal(expected))
            pipeline.run()