Example #1
0
 def test_050_check_variant_contents(self):
     """Fetch every record from the reader and compare them, in order, with the expected variants."""
     expected = [
         Variant(chrom='chr1', pos=14369, ref='G', alt=['A'],
                 ident='rs6054257', qual=29, filt='PASS',
                 info={'NS': 3, 'DP': 14, 'AF': 0.5, 'DB': True, 'H2': True},
                 genotype_data=OrderedDict(
                     [('GT', '1|0'), ('GQ', '48'), ('DP', '8'), ('HQ', '51,51')])),
         Variant(chrom='chr2', pos=17329, ref='T', alt=['A'],
                 ident='.', qual=3, filt='q10',
                 info={'NS': 3, 'DP': 11, 'AF': 0.017},
                 genotype_data=OrderedDict(
                     [('GT', '0|0'), ('GQ', '49'), ('DP', '3'), ('HQ', '58,50')])),
         Variant(chrom='chr10', pos=1110695, ref='A', alt=['G', 'T'],
                 ident='rs6040355', qual=67, filt='PASS',
                 info={'NS': 2, 'DP': 10, 'AF': [0.333, 0.667], 'AA': 'T', 'DB': True},
                 genotype_data=OrderedDict(
                     [('GT', '1|2'), ('GQ', '21'), ('DP', '6'), ('HQ', '23,27')])),
         Variant(chrom='chr20', pos=1230236, ref='T', alt=['.'],
                 ident='.', qual=47, filt='PASS',
                 info={'NS': 3, 'DP': 13, 'AA': 'T'},
                 genotype_data=OrderedDict(
                     [('GT', '0|0'), ('GQ', '54'), ('DP', '7'), ('HQ', '56,60')])),
         Variant(chrom='chrX', pos=1234566, ref='GTCT', alt=['G', 'GTACT'],
                 ident='microsat1', qual=50, filt='PASS',
                 info={'NS': 3, 'DP': 9, 'AA': 'G'},
                 genotype_data=OrderedDict(
                     [('GT', '1/1'), ('GQ', '40'), ('DP', '3')])),
     ]
     self.assertSequenceEqual(list(self.vcf_reader.fetch()), expected)
Example #2
0
 def test_007_check_trim_ref(self):
     # ref and alt share the suffix 'CACAG' and the leading 'T'; trimming must
     # not strip either allele down to nothing, so the leading 'T' is kept.
     original = Variant('20', 14369, ref='TAGTCACAG', alt=['TCACAG'])
     expected = Variant('20', 14369, ref='TAGT', alt=['T'])
     self.assertEqual(expected, original.trim(),
                      'Trimming failed for {}.'.format(expected))
Example #3
0
    def test_002_check_merge_snps(self):
        """Two different homozygous SNPs at one locus merge into a single 1/2 variant."""
        ref_seq = 'ATGGTATGCGATTGACC'
        chrom = 'chrom1'

        hap1 = [Variant(chrom, 0, 'A', alt='C', qual=5, sample_dict={'GT': '1/1'})]
        hap2 = [Variant(chrom, 0, 'A', alt='T', qual=10, sample_dict={'GT': '1/1'})]
        expected = Variant(chrom, 0, 'A', alt=['C', 'T'], qual=7.5,
                           sample_dict={'GT': '1/2'})

        comb_interval, trees = self.intervaltree_prep(hap1, hap2, ref_seq)
        # preserve phase otherwise alts could be switched around
        merged = _merge_variants(comb_interval, trees, ref_seq, discard_phase=False)
        for attr in ('chrom', 'pos', 'ref', 'qual', 'alt', 'gt'):
            want = getattr(expected, attr)
            have = getattr(merged, attr)
            self.assertEqual(
                want, have,
                'Merging failed for {}:{} {}.'.format(expected.chrom, expected.pos + 1, attr))
Example #4
0
 def test_004_check_homo(self):
     """Splitting a 1/1 call over alts ['T', 'A'] yields alt 'T' on both haplotypes."""
     v_orig = Variant('20', 14369, 'G', alt=['T', 'A'], qual=10,
                      sample_dict=OrderedDict([('GT', '1/1'), ('GQ', 10.0)]))
     hap_sample = v_orig.sample_dict.copy()
     hap_sample['GT'] = '1/1'
     # one (haplotype index, variant) pair per haplotype, each a distinct Variant
     expt = tuple(
         (haplotype,
          Variant('20', 14369, 'G', alt=['T'], qual=10, sample_dict=hap_sample))
         for haplotype in (1, 2))
     self.assertEqual(expt, v_orig.split_haplotypes(),
                      'Splitting haplotypes failed for {}'.format(v_orig))
Example #5
0
 def test_005_check_homo(self):
     """Splitting a 2/2 call over alts ['T', 'A'] yields alt 'A' (as 1/1) on both haplotypes."""
     v_orig = Variant('20', 14369, 'G', alt=['T', 'A'], qual=10,
                      genotype_data=OrderedDict([('GT', '2/2'), ('GQ', 10.0)]))
     hap_genotype = v_orig.genotype_data.copy()
     hap_genotype['GT'] = '1/1'
     # one (haplotype index, variant) pair per haplotype, each a distinct Variant
     expt = tuple(
         (haplotype,
          Variant('20', 14369, 'G', alt=['A'], qual=10, genotype_data=hap_genotype))
         for haplotype in (1, 2))
     self.assertEqual(expt, v_orig.split_haplotypes(),
                      'Splitting haplotypes failed for {}'.format(v_orig))
Example #6
0
    def test_vcf_annotate(self):
        """Annotate the test VCF and compare every output record with the expectation.

        The expected records are the three annotated variants, followed by the
        same three variants on the contig "Duplicate".
        """
        variants_annotated = [
                Variant('MN908947.3', 29748, 'ACGATCGAGTG', alt=['A'],
                    ident='.', qual=243.965, filt='PASS',
                    info='AR=0,0;DP=200;DPS=100,100;DPSP=199;SC=19484,20327,22036,23215;SR=1,2,98,98',
                    genotype_data=OrderedDict([('GT','1'), ('GQ', '244')])),
                Variant('MN908947.3', 29764, 'TGAACAATGCT',
                    alt=['A'], ident='.', qual=243.965, filt='PASS',
                    info='AR=0,0;DP=200;DPS=100,100;DPSP=199;SC=19970,21140,15773,16751;SR=99,100,0,0',
                    genotype_data=OrderedDict([('GT','1'), ('GQ', '244')])),
                Variant('MN908947.3', 29788, 'TATATGGAAGA',
                    alt=['A'], ident='.', qual=243.965, filt='PASS',
                    info='AR=0,0;DP=199;DPS=99,100;DPSP=197;SC=26174,28129,19085,20315;SR=96,100,1,0',
                    genotype_data=OrderedDict([('GT', '1'), ('GQ','244')]))]
        variants_annotated = variants_annotated + deepcopy(variants_annotated)
        for duplicate in variants_annotated[3:6]:
            duplicate.chrom = "Duplicate"

        with tempfile.NamedTemporaryFile() as vcfout:
            # Annotate vcf
            args = Namespace(RG=self.rg, vcf=self.vcf, ref_fasta=self.ref_fasta,
                             bam=self.bam, vcfout=vcfout.name,
                             chunk_size=100000, pad=25, dpsp=True)
            annotate_vcf_n_reads(args)

            # Read in output variants and compare with expected annotated variants.
            # Materialise them first so that missing records fail the test
            # (iterating only over the output would silently pass on an empty
            # file) and extra records fail cleanly instead of with IndexError.
            got = list(VCFReader(vcfout.name).fetch())
            self.assertEqual(len(got), len(variants_annotated),
                             'Unexpected number of annotated variants.')
            for i, (v, expected) in enumerate(zip(got, variants_annotated)):
                self.assertEqual(v, expected,
                                 'Annotation failed for variant {}: {} {}.'.format(i, v.chrom, v.pos))
Example #7
0
    def test_004_check_merge_multi_bug(self):
        # if we have two indels on one haplotype that cancel each other out
        # (e.g. insertion of a T followed by a deletion of a T)
        # check we don't have an alt that is the same as the ref.
        ref_seq = 'TTTTTTTTTT'
        chrom = 'chrom1'

        def hom_alt(pos, ref, alt, qual):
            # homozygous-alt input variant with its own genotype dict
            return Variant(chrom, pos, ref, alt=alt, qual=qual,
                           genotype_data={'GT': '1/1'})

        h1 = [hom_alt(0, 'TTTTT', 'T', 5)]
        # the insertion and deletion on h2 cancel, so only h1's deletion remains
        h2 = [hom_alt(1, 'T', 'TT', 10), hom_alt(3, 'TT', 'T', 10)]
        expt = Variant(chrom, 0, 'TTTTT', alt='T', qual=5,
                       genotype_data={'GT': '1|0'})

        comb_interval, trees = self.intervaltree_prep(h1, h2, ref_seq)
        # preserve phase otherwise alts could be switched around
        got = _merge_variants(comb_interval, trees, ref_seq, discard_phase=False)
        for key in ('chrom', 'pos', 'ref', 'qual', 'alt', 'gt', 'phased'):
            self.assertEqual(
                getattr(expt, key), getattr(got, key),
                'Merging failed for {}:{} {}.'.format(expt.chrom, expt.pos + 1, key))
Example #8
0
 def test_classify_variant(self):
     """classify_variant assigns the expected class to each (ref, alts) pair."""
     # (expected class, ref, alts)
     cases = [
         ('snp', 'G', ['A']),
         ('mnp', 'GG', ['AT']),
         ('snd', 'GA', ['G']),
         ('snd', 'GA', ['A']),
         ('mnd', 'GAT', ['G']),
         ('mnd', 'GAA', ['A']),
         ('sni', 'G', ['GT']),
         ('sni', 'G', ['AG']),
         ('mni', 'G', ['GTC']),
         ('mni', 'G', ['ATG']),
         ('other', 'G', ['ATA']),
         ('other', 'GAC', ['T']),
         # Classed as sub + ins, but with alignment one could class it as a mnd
         ('other', 'G', ['TGC']),
         # Classed as sub + del, but with alignment one could class it as a mni
         ('other', 'GAT', ['A']),
         ('indel', 'GG', ['G', 'GGA']),
         ('mnd', 'GGG', ['G', 'GG']),
         ('mni', 'G', ['GG', 'GGG']),
         ('mnp', 'GA', ['CT', 'CA']),
     ]
     for expected_class, ref, alts in cases:
         observed = classify_variant(Variant('20', 14369, ref, alt=alts))
         self.assertEqual(
             expected_class, observed,
             'Classification failed for {} {} {}'.format(ref, alts, expected_class))
Example #9
0
 def test_get_padded_haplotypes(self):
     """Check padded ref/alt haplotypes and the padded region for each case."""
     chrom = 'my_chrom'
     ref_seq = 'ATGCTACTGC'
     # (pos, ref, alt), pad, padded ref, padded alt, start, end
     cases = [
         ((4, 'T', 'G'), 2, 'GCTAC', 'GCGAC', 2, 7),  #  sub
         ((4, 'T', 'TA'), 2, 'GCTAC', 'GCTAAC', 2, 7),  #  ins
         ((4, 'T', 'GA'), 2, 'GCTAC', 'GCGAAC', 2, 7),  #  sub ins
         ((4, 'TA', 'T'), 2, 'GCTACT', 'GCTCT', 2, 8),  #  del
         ((4, 'TA', 'G'), 2, 'GCTACT', 'GCGCT', 2, 8),  #  sub del
         # test what happens for variant at start and end of chrom
         ((0, 'A', 'G'), 2, 'ATG', 'GTG', 0, 3),  #  sub at start
         ((0, 'A', 'AG'), 2, 'ATG', 'AGTG', 0, 3),  #  ins at start
         ((0, 'AT', 'T'), 2, 'ATGC', 'TGC', 0, 4),  #  del at start
         ((9, 'C', 'G'), 2, 'TGC', 'TGG', 7, 10),  #  sub at end
         ((9, 'C', 'CG'), 2, 'TGC', 'TGCG', 7, 10),  #  ins at end
         ((8, 'GC', 'G'), 2, 'CTGC', 'CTG', 6, 10),  #  del at end
     ]
     for ((pos, ref, alt), pad, pad_ref, pad_alt, start, end) in cases:
         var = Variant(chrom, pos, ref, alt)
         padded, region = get_padded_haplotypes(var, ref_seq, pad)
         # name the case in every assertion; without it a failure inside the
         # loop does not say which of the cases broke
         msg = 'Padding failed for {}:{} {}>{} with pad {}.'.format(
             chrom, pos, ref, alt, pad)
         self.assertEqual(pad_ref, padded[0], msg)
         self.assertEqual(pad_alt, padded[1], msg)
         self.assertEqual(region.start, start, msg)
         self.assertEqual(region.end, end, msg)
Example #10
0
 def setUpClass(cls):
     # Keyword arguments of the reference variant shared by the tests, and the
     # variant built from them; tests copy and tweak base_parameters.
     cls.base_parameters = dict(
         chrom='chr1',
         pos=14369,
         ref='G',
         alt=['A'],
         ident='rs6054257',
         qual=29,
         filt='PASS',
         info=dict(NS=3, DP=14, AF=0.5, DB=True, H2=True),
         genotype_data=OrderedDict(
             [('GT', '1|0'), ('DP', '8'), ('GQ', '48'), ('HQ', '51,51')]),
     )
     cls.variant = Variant(**cls.base_parameters)
Example #11
0
def merge_haploid_vcfs(vcf1, vcf2, vcf_out):
    "Merge SNPs from two haploid VCFs into an unphased diploid vcf."
    loci_by_chrom = defaultdict(set)

    vcf1 = VCFReader(vcf1)
    vcf2 = VCFReader(vcf2)

    for v in chain(vcf1.fetch(), vcf2.fetch()):
        loci_by_chrom[v.chrom].add(v.pos)

    with VCFWriter(vcf_out, 'w', version='4.1') as vcf_writer:
        for chrom, loci in loci_by_chrom.items():
            for pos in sorted(loci):
                v1 = list(vcf1.fetch(ref_name=chrom, start=pos, end=pos+1))
                v2 = list(vcf2.fetch(ref_name=chrom, start=pos, end=pos+1))

                # the QC is -10*log10(1-p(label)) where p(label) is the medaka consensus
                # probability. To combine these, we probably want to multiply the
                # (1-p(label)) values, i.e. add the QC scores. However, in the case of a
                # herterozygous SNPs where one of the haplotypes is the reference, we
                # won't have the QC value of the reference haplotype (no variant was
                # called).
                # Hence if we want a common scale we need to assume we can apprimate the missing
                # QC score for the reference haplotypes as being equal to the non-reference
                # haplotype so we can set the overall score to double the latter.
                def get_gq(v1, v2):
                    if len(v1) == 1 and len(v2) == 1:
                        gq = float(v1[0].sample_dict['GQ']) + float(v2[0].sample_dict['GQ'])
                    else:
                        v = v1[0] if len(v1) == 1 else v2[0]
                        gq = 2 * float(v.sample_dict['GQ'])
                    return gq

                def get_ref(v1, v2):
                    return v1[0].ref if len(v1) == 1 else v2[0].ref

                # Note we output unphased GTs as we might have multiple phased
                # regions and the phase can switch between regions

                # heterozygous on v1:
                if len(v1) == 1 and (len(v2) == 0 or v2[0].alt == ['.']):
                    alt = v1[0].alt
                    gt = '0/1'  # not 1/0 by convention since this is unphased
                # heterozygous on v2
                elif (len(v1) == 0 or v1[0].alt == ['.']) and len(v2) == 1:
                    alt = v2[0].alt
                    gt = '0/1'
                else:
                    assert len(v1) == 1 and len(v2) == 1
                    if v1[0].alt == v2[0].alt:  #homozygous snp
                        alt = v1[0].alt
                        gt = '1/1'
                    else:  #heterozygous snp
                        alt = v1[0].alt + v2[0].alt
                        gt = '1/2'

                gq = get_gq(v1, v2)
                v = Variant(chrom, pos, get_ref(v1, v2), alt=alt, qual=gq, sample_dict={'GT':gt, 'GQ':gq})
                vcf_writer.write_variant(v)
Example #12
0
    def test_003_check_merge_multi(self):
        """Overlapping indels and SNPs on two haplotypes merge into one multi-allelic variant."""
        ref_seq = 'ATGGTATGCGATTGACC'
        chrom = 'chrom1'

        def hom(pos, ref, alt, qual):
            # homozygous-alt input variant with its own genotype dict
            return Variant(chrom, pos, ref, alt=alt, qual=qual,
                           genotype_data={'GT': '1/1'})

        h1 = [hom(0, 'ATG', 'G', 5),
              hom(4, 'T', 'G', 5),
              hom(7, 'G', 'GG', 5),
              hom(9, ref_seq[9], ref_seq[9] + 'T', 5)]
        h2 = [hom(1, 'T', 'TT', 10),
              hom(2, ref_seq[2:10], ref_seq[2], 10)]

        # POS  0    1   2   3   4   5   6   7   8   9   10
        # REF  A    T   G   G   T   A   T   G   C   G   A
        # H1   -    -   G   G   g   A   T   Gg  C   Gt  A
        # H2   A    Tt  G   -   -   -   -   -   -   -   A

        # expected merged variants
        ref_expt = "ATGGTATGCG"
        alt1_expt = "GGGATGGCGT"
        alt2_expt = "ATTG"
        expt = Variant(chrom, 0, ref_expt, alt=[alt1_expt, alt2_expt], qual=7.5,
                       genotype_data={'GT': '1|2'})

        comb_interval, trees = self.intervaltree_prep(h1, h2, ref_seq)
        # preserve phase otherwise alts could be switched around
        got = _merge_variants(comb_interval, trees, ref_seq, discard_phase=False)
        for key in ('chrom', 'pos', 'ref', 'qual', 'alt', 'gt', 'phased'):
            self.assertEqual(
                getattr(expt, key), getattr(got, key),
                'Merging failed for {}:{} {}.'.format(expt.chrom, expt.pos+1, key))
Example #13
0
 def _make_variant(self, pos, ref, alt, gt):
     """Build a variant on ``self.chrom`` with a phased GT from the first two entries of ``gt``.

     GQ comes from ``self.qual`` and INFO from ``self.info``.
     """
     genotype = {'GT': '{}|{}'.format(*gt), 'GQ': self.qual}
     return Variant(self.chrom, pos, ref, alt,
                    genotype_data=genotype, info=self.info)
Example #14
0
    def test_020_inequalities(self):
        """Check variants compare equal only when every constructor parameter matches."""

        # Alternative parameters with every value different from the base ones
        alternative_parameters = {
            'chrom': 'chr2',
            'pos': 1,
            'ref': 'T',
            'alt': ['C'],
            'ident': 'rt',
            'qual': 1,
            'filt': '.',
            'info': {
                'NS': 1,
                'DP': 1,
                'AF': 0.1,
                'DB': False,
                'H2': False
            },
            'genotype_data': OrderedDict([('GT', '0|0'), ('DP', '7'), ('GQ', '12'),
                                          ('HQ', '5,5')])
        }

        # A variant created with the same parameters should be equal,
        # but not the same object
        variant1 = Variant(**self.base_parameters)
        self.assertIsNot(self.variant, variant1)
        self.assertEqual(self.variant, variant1)

        # Changing one parameter at a time must break equality
        for changing_key, changed_value in alternative_parameters.items():
            new_parameters = self.base_parameters.copy()
            new_parameters[changing_key] = changed_value
            variant1 = Variant(**new_parameters)
            self.assertNotEqual(
                variant1, self.variant,
                'Changing {!r} did not make the variants unequal.'.format(changing_key))
Example #15
0
    def test_030_genotype_info(self):
        """Check the genotype accessors with and without genotype data."""
        # self.variant has genotype information; this one is built without any
        no_genotype = Variant('chr1', 14369, 'G', alt=['A'], ident='rs6054257', qual=29)

        self.assertEqual(self.variant.genotype_keys, 'GT:DP:GQ:HQ')
        self.assertEqual(self.variant.genotype_values, '1|0:8:48:51,51')
        self.assertEqual(self.variant.gt, (1, 0))

        # If no genotype info is present, gt is expected to be None
        self.assertEqual(no_genotype.gt, None)
Example #16
0
 def test_001_check_trim_start(self):
     # the shared leading 'GG' is removed and the position moves on by two
     expected = Variant('20', 14371, 'C', alt=['A'])
     trimmed = Variant('20', 14369, 'GGC', alt=['GGA']).trim()
     self.assertEqual(expected, trimmed,
                      'Trimming failed for {}.'.format(expected))
Example #17
0
 def test_006_check_trim_ref(self):
     # 'CCTG' and 'C' share no suffix, and dropping the shared leading 'C' would
     # leave an empty alt, so trim() is expected to return an equal variant.
     v_orig = Variant('20', 14369, 'CCTG', alt=['C'])
     # Build the expectation independently: aliasing v_orig would make the
     # assertion vacuous if trim() modified v_orig in place and returned it.
     v_expt = Variant('20', 14369, 'CCTG', alt=['C'])
     v_trim = v_orig.trim()
     self.assertEqual(v_expt, v_trim,
                      'Trimming failed for {}.'.format(v_expt))
Example #18
0
 def test_015_empty_info(self):
     # without an info argument the INFO string is rendered as "."
     kwargs = deepcopy(self.base_parameters)
     kwargs.pop('info')
     self.assertEqual(Variant(**kwargs).info_string, ".")
Example #19
0
 def test_raises(self):
     ref_seq = 'ATGCTACTGC'
     # ref_seq[2:4] is 'GC', so a 'GT' ref at position 2 disagrees with the reference
     mismatched = Variant('my_chrom', 2, 'GT', 'G')
     self.assertRaises(ValueError, get_padded_haplotypes, mismatched, ref_seq, 2)
Example #20
0
 def test_002_check_trim_end(self):
     # the shared trailing 'GG' is removed; the position does not change
     expected = Variant('20', 14369, 'C', alt=['A'])
     trimmed = Variant('20', 14369, 'CGG', alt=['AGG']).trim()
     self.assertEqual(expected, trimmed,
                      'Trimming failed for {}.'.format(expected))
Example #21
0
def find_snps(probs_hdfs, ref_fasta, out_file, regions=None, threshold=0.1, ref_vcf=None):
    """Find potential homozygous and heterozygous snps based on a probability threshold.

    :param probs_hdfs: iterable of hdf filepaths.
    :param ref_fasta: reference fasta.
    :param out_file: vcf output file.
    :param regions: iterable of region strings. Only the reference name of each
            region is used; any start:end part is ignored (with a warning).
            If `None`, every reference in the hdf index is processed.
    :param threshold: threshold below which a secondary call (which would make
            for a heterozygous call) is deemed insignificant.
    :param ref_vcf: input vcf whose loci are always reported: loci where no SNP
            is called are written as reference calls (ALT '.'). Every record is
            also annotated with the ALT and GT found in this vcf at that locus.
    """
    logger = medaka.common.get_named_logger('SNPs')

    index = DataIndex(probs_hdfs)

    label_decoding = index.meta['medaka_label_decoding']
    # include extra bases so we can encode e.g. N's in reference
    label_encoding = {label: ind for ind, label in enumerate(label_decoding + list(medaka.common._extra_bases_))}
    gap_encoded = label_encoding[medaka.common._gap_]

    fmt_feat = lambda x: '{}{}'.format('rev' if x[0] else 'fwd', x[2] * (x[1] if x[1] is not None else '-'))
    feature_row_names = [fmt_feat(x) for x in index.meta['medaka_feature_decoding']]

    logger.debug("Label decoding is:\n{}".format('\n'.join(
        '{}: {}'.format(i, x) for i, x in enumerate(label_decoding)
    )))
    if regions is None:
        ref_names = index.index.keys()
    else:
        #TODO: respect entire region specification
        ref_names = list()
        for region in (medaka.common.Region.from_string(r) for r in regions):
            # only the reference name is used: warn when a sub-region was requested
            if region.start is not None or region.end is not None:
                logger.warning("Ignoring start:end for '{}'.".format(region))
            ref_names.append(region.ref_name)

    def _get_ref_variant(ref_vcf, ref_name, pos):
        # ALT/GT of the (at most one) ref_vcf record at pos, or 'na' if absent
        ref_variants = list(ref_vcf.fetch(ref_name=ref_name, start=pos, end=pos+1))
        assert len(ref_variants) < 2
        if len(ref_variants) == 0:
            ref_info = {'ref_alt': 'na', 'ref_gt': 'na'}
        else:
            ref_info = {'ref_alt': ','.join(ref_variants[0].alt),
                        'ref_gt': ref_variants[0].sample_dict['GT']}
        return ref_info

    if ref_vcf is not None:
        ref_vcf = medaka.vcf.VCFReader(ref_vcf)
        vcf_loci = defaultdict(set)
        for v in ref_vcf.fetch():
            vcf_loci[v.chrom].add(v.pos)

    # For SNPS, we assume we just want the label probabilities at major positions
    # We need to handle boundary effects simply to make sure we take the best
    # probability (which is furthest from the boundary).
    # The reference fasta is opened once for all contigs and closed on exit.
    with VCFWriter(out_file, 'w', version='4.1') as vcf_writer, \
            pysam.FastaFile(ref_fasta) as fasta:
        for ref_name in ref_names:
            called_loci = set()
            ref_seq = fasta.fetch(reference=ref_name)
            logger.info("Processing {}.".format(ref_name))
            # TODO: refactor this and stitch to use common func/generator to get
            # chunks with overlapping bits trimmed off
            data_gen = index.yield_from_feature_files(ref_names=(ref_name,))
            s1 = next(data_gen)
            start_1_ind = None  # don't trim beginning of s1
            for s2 in chain(data_gen, (None,)):
                if s2 is None:  # s1 is last chunk
                    end_1_ind = None  # go to the end of s1
                    start_2_ind = None  # no following chunk (also covers single-chunk contigs)
                else:
                    end_1_ind, start_2_ind = medaka.common.get_sample_overlap(s1, s2)

                pos = s1.positions[start_1_ind:end_1_ind]
                probs = s1.label_probs[start_1_ind:end_1_ind]
                # discard minor positions (insertions)
                major_inds = np.where(pos['minor'] == 0)
                major_pos = pos[major_inds]['major']
                major_probs = probs[major_inds]
                major_feat = s1.features[start_1_ind:end_1_ind][major_inds] if s1.features is not None else None

                # for homozygous SNP max_prob_label not in {ref, del} and
                # (2nd_max_prob < threshold or 2nd_max_prob_label is del)
                # for heterozygous SNP 2nd_max_prob > threshold and del not in {max_prob_label, 2nd_prob_label}
                # this catches both SNPs where the genotype contains the
                # reference, and where both copies are mutated.

                sorted_prob_inds = np.argsort(major_probs, -1)
                sorted_probs = np.take_along_axis(major_probs, sorted_prob_inds, axis=-1)
                primary_labels = sorted_prob_inds[:, -1]
                secondary_labels = sorted_prob_inds[:, -2]
                primary_probs = sorted_probs[:, -1]
                secondary_probs = sorted_probs[:, -2]
                # skip positions where ref is not a label (ATGC)
                ref_seq_encoded = np.fromiter((label_encoding[ref_seq[i]] for i in major_pos), int, count=len(major_pos))
                is_ref_valid_label = np.isin(ref_seq_encoded, np.arange(len(label_decoding)))

                # homozygous SNPs
                is_primary_diff_to_ref = np.not_equal(primary_labels, ref_seq_encoded)
                is_primary_not_del = primary_labels != gap_encoded
                is_secondary_del = secondary_labels == gap_encoded
                is_secondary_prob_lt_thresh = secondary_probs < threshold
                is_not_secondary_call = np.logical_or(is_secondary_del, is_secondary_prob_lt_thresh)
                is_homozygous_snp = np.logical_and(is_primary_diff_to_ref, is_primary_not_del)
                is_homozygous_snp = np.logical_and(is_homozygous_snp, is_not_secondary_call)
                is_homozygous_snp = np.logical_and(is_homozygous_snp, is_ref_valid_label)
                homozygous_snp_inds = np.where(is_homozygous_snp)
                # heterozygous SNPs
                is_secondary_prob_ge_thresh = np.logical_not(is_secondary_prob_lt_thresh)
                is_secondary_not_del = secondary_labels != gap_encoded
                is_heterozygous_snp = np.logical_and(is_secondary_prob_ge_thresh, is_secondary_not_del)
                is_heterozygous_snp = np.logical_and(is_heterozygous_snp, is_primary_not_del)
                is_heterozygous_snp = np.logical_and(is_heterozygous_snp, is_ref_valid_label)
                heterozygous_snp_inds = np.where(is_heterozygous_snp)

                def site_info(i):
                    # INFO fields shared by every record emitted for index i of
                    # the current chunk; only called within this iteration.
                    info = {'ref_prob': major_probs[i][ref_seq_encoded[i]],
                            'primary_prob': primary_probs[i],
                            'primary_label': label_decoding[primary_labels[i]],
                            'secondary_prob': secondary_probs[i],
                            'secondary_label': label_decoding[secondary_labels[i]],
                            }
                    if ref_vcf is not None:
                        info.update(_get_ref_variant(ref_vcf, ref_name, major_pos[i]))
                    if major_feat is not None:
                        info.update(dict(zip(feature_row_names, major_feat[i])))
                    return info

                variants = []
                for i in homozygous_snp_inds[0]:
                    ref_base_encoded = ref_seq_encoded[i]
                    info = site_info(i)
                    qual = -10 * np.log10(1 - primary_probs[i])
                    sample = {'GT': '1/1', 'GQ': qual}
                    variants.append(Variant(ref_name, major_pos[i], label_decoding[ref_base_encoded],
                                            alt=label_decoding[primary_labels[i]],
                                            filter='PASS', info=info, qual=qual, sample_dict=sample))

                for i in heterozygous_snp_inds[0]:
                    ref_base_encoded = ref_seq_encoded[i]
                    info = site_info(i)
                    qual = -10 * np.log10(1 - primary_probs[i] - secondary_probs[i])
                    alt = [label_decoding[label] for label in (primary_labels[i], secondary_labels[i])
                           if label != ref_base_encoded]
                    gt = '0/1' if len(alt) == 1 else '1/2'  # / => unphased
                    sample = {'GT': gt, 'GQ': qual}
                    variants.append(Variant(ref_name, major_pos[i], label_decoding[ref_base_encoded],
                                            alt=alt, filter='PASS', info=info, qual=qual, sample_dict=sample))

                if ref_vcf is not None:
                    # if we provided a vcf, check which positions are missing
                    called_loci.update({v.pos for v in variants})
                    missing_loci = vcf_loci[ref_name] - called_loci
                    missing_loci_in_chunk = missing_loci.intersection(major_pos)
                    missing_loci_in_chunk = np.fromiter(missing_loci_in_chunk, int, count=len(missing_loci_in_chunk))
                    is_missing = np.isin(major_pos, missing_loci_in_chunk)
                    missing_snp_inds = np.where(is_missing)

                    for i in missing_snp_inds[0]:
                        ref_base_encoded = ref_seq_encoded[i]
                        info = site_info(i)
                        qual = -10 * np.log10(1 - primary_probs[i])
                        # haploid reference call; GT is a string like every other record
                        sample = {'GT': '0', 'GQ': qual}
                        variants.append(Variant(ref_name, major_pos[i], label_decoding[ref_base_encoded],
                                                alt='.',
                                                filter='PASS', info=info, qual=qual, sample_dict=sample))

                variants.sort(key=lambda v: v.pos)
                for variant in variants:
                    vcf_writer.write_variant(variant)

                # only report a missing overlap when there actually is a next chunk
                if s2 is not None and end_1_ind is None and start_2_ind is None:
                    msg = 'There is no overlap betwen {} and {}'
                    logger.info(msg.format(s1.name, s2.name))
                s1 = s2
                start_1_ind = start_2_ind
Example #22
0
 def test_005_check_trim_ref(self):
     # multi-allelic: the 'AT' prefix and 'GG' suffix are shared by ref and both
     # alts, so only the middle base remains, two positions further on
     expected = Variant('20', 14371, 'C', alt=['A', 'G'])
     trimmed = Variant('20', 14369, 'ATCGG', alt=['ATAGG', 'ATGGG']).trim()
     self.assertEqual(expected, trimmed,
                      'Trimming failed for {}.'.format(expected))