Exemplo n.º 1
0
def combine_pdfs(BP, c, use_product, weighting_scheme):
    """Combine the left/right position distributions of a set of breakpoints.

    The left and right intervals of every breakpoint in ``c`` are aligned
    with ``l_bp.align_intervals``. Each breakpoint's distribution is
    optionally weighted. The distributions are then combined by summation
    (``ALG='SUM'``). When ``use_product`` is set and every breakpoint has
    non-zero mass at the summed maximum on both sides, they are instead
    combined by a product computed with the ``ls`` log-space helpers
    (``ALG='PROD'``). The result is normalized, trimmed with ``l_bp.trim``
    and renormalized.

    Args:
        BP: indexable collection of breakpoint objects exposing
            ``left``/``right`` (each with ``start``, ``end`` and ``p``) and a
            tab-separated record line ``l`` whose 8th column is parsed with
            ``l_bp.to_map``.
        c: indices into ``BP`` of the breakpoints to combine.
        use_product: attempt the product combination before keeping the sum.
        weighting_scheme: ``'evidence_wt'`` scales each breakpoint by its
            INFO ``SU`` value; ``'carrier_wt'`` scales it by the number of
            comma-separated ``SNAME`` entries (1 if SNAME is absent). Any
            other value leaves the distributions unweighted.

    Returns:
        ``(new_start_L, new_start_R, p_L, p_R, ALG)``: the trimmed start
        coordinates, the trimmed and normalized distributions, and the
        combination algorithm actually used (``'SUM'`` or ``'PROD'``).
    """
    L = []
    R = []
    for b_i in c:
        b = BP[b_i]
        L.append([b.left.start, b.left.end, b.left.p])
        R.append([b.right.start, b.right.end, b.right.p])

    [start_R, end_R, a_R] = l_bp.align_intervals(R)
    [start_L, end_L, a_L] = l_bp.align_intervals(L)

    p_L = [0] * len(a_L[0])
    p_R = [0] * len(a_R[0])

    for c_i in range(len(c)):
        wt = None
        if weighting_scheme == 'evidence_wt':
            # weight by the breakpoint's INFO SU value
            A = BP[c[c_i]].l.rstrip().split('\t', 10)
            wt = int(l_bp.to_map(A[7])['SU'])
        elif weighting_scheme == 'carrier_wt':
            # weight by the number of comma-separated SNAME entries
            A = BP[c[c_i]].l.rstrip().split('\t', 10)
            m = l_bp.to_map(A[7])
            wt = len(m['SNAME'].split(',')) if 'SNAME' in m else 1
        if wt is not None:
            a_L[c_i] = [wt * ali for ali in a_L[c_i]]
            a_R[c_i] = [wt * ari for ari in a_R[c_i]]

        for i, v in enumerate(a_L[c_i]):
            p_L[i] += v
        for i, v in enumerate(a_R[c_i]):
            p_R[i] += v

    ALG = 'SUM'
    if use_product:
        pmax_i_L = p_L.index(max(p_L))
        pmax_i_R = p_R.index(max(p_R))

        # A member with zero mass at either summed maximum would zero the
        # product there, so the product is only used when none does.
        if all(a_L[c_i][pmax_i_L] != 0 and a_R[c_i][pmax_i_R] != 0
               for c_i in range(len(c))):
            ALG = "PROD"
            ls_p_L = [ls.get_ls(1)] * len(a_L[0])
            ls_p_R = [ls.get_ls(1)] * len(a_R[0])

            for c_i in range(len(c)):
                for i, v in enumerate(a_L[c_i]):
                    ls_p_L[i] = ls.ls_multiply(ls_p_L[i], ls.get_ls(v))
                for i, v in enumerate(a_R[c_i]):
                    ls_p_R[i] = ls.ls_multiply(ls_p_R[i], ls.get_ls(v))

            ls_sum_L = ls.get_ls(0)
            ls_sum_R = ls.get_ls(0)
            for ls_p in ls_p_L:
                ls_sum_L = ls.ls_add(ls_sum_L, ls_p)
            for ls_p in ls_p_R:
                ls_sum_R = ls.ls_add(ls_sum_R, ls_p)

            p_L = [ls.get_p(ls.ls_divide(ls_p, ls_sum_L)) for ls_p in ls_p_L]
            p_R = [ls.get_p(ls.ls_divide(ls_p, ls_sum_R)) for ls_p in ls_p_R]

    # Normalize each side by its own total (the right side was previously
    # divided by the left-side sum, which skewed the trimming below).
    sum_L = sum(p_L)
    sum_R = sum(p_R)
    p_L = [x / sum_L for x in p_L]
    p_R = [x / sum_R for x in p_R]

    [clip_start_L, clip_end_L] = l_bp.trim(p_L)
    [clip_start_R, clip_end_R] = l_bp.trim(p_R)

    new_start_L = start_L + clip_start_L
    new_start_R = start_R + clip_start_R

    p_L = p_L[clip_start_L:len(p_L) - clip_end_L]
    p_R = p_R[clip_start_R:len(p_R) - clip_end_R]

    # renormalize after trimming
    s_p_L = sum(p_L)
    s_p_R = sum(p_R)
    p_L = [x / s_p_L for x in p_L]
    p_R = [x / s_p_R for x in p_R]

    return new_start_L, new_start_R, p_L, p_R, ALG
Exemplo n.º 2
0
def merge(BP, sample_order, v_id, use_product):
    """Merge overlapping breakpoints into records and print them.

    A lone breakpoint is re-emitted with a fresh ID. Its old ID is appended
    to SNAME, its MATEID/EVENT entries are dropped, and EVENT and ALG are
    added.

    Otherwise the breakpoints are repeatedly partitioned by a sweep. The
    sweep finds the largest set whose left intervals share a common point
    and whose right intervals share a common point, removes it, and
    repeats. Each set is then collapsed into one record. Its position
    distributions are the sum of the members' distributions. When
    ``use_product`` is set and the product is usable, they are instead the
    log-space product.

    Args:
        BP: list of breakpoint objects exposing ``start_l``/``end_l``/``p_l``,
            ``start_r``/``end_r``/``p_r``, ``chr_l``/``chr_r``, ``sv_type``,
            ``strands`` and a tab-separated record line ``l``. The list is
            sorted in place by ``start_l``.
        sample_order: not used by this function.
        v_id: the last variant ID issued; every printed record takes the
            next one.
        use_product: attempt the product combination of distributions.

    Returns:
        The last variant ID issued.
    """
    if len(BP) == 1:
        A = BP[0].l.rstrip().split('\t')

        # Tack the original id onto SNAME. This is skipped when SNAME is
        # absent; otherwise find() returns -1 and the id would be glued
        # onto whichever INFO entry happens to be last.
        s_start = A[7].find('SNAME=')
        if s_start > -1:
            s_end = A[7].find(';', s_start)
            if s_end > -1:
                A[7] = A[7][:s_end] + ':' + A[2] + A[7][s_end:]
            else:
                A[7] += ':' + A[2]

        # reset the id to be unique in this file
        v_id += 1
        A[2] = str(v_id)

        # Clip out the old mate and event ids. Splitting on ';' matches whole
        # keys and leaves no dangling ';' when the entry was the last one.
        A[7] = ';'.join(f for f in A[7].split(';')
                        if not f.startswith(('MATEID=', 'EVENT=')))

        # add the new event id and the combination algorithm
        A[7] += ';EVENT=' + A[2]
        A[7] += ';ALG=PROD' if use_product else ';ALG=SUM'

        print_var_line('\t'.join(A))
        return v_id

    # Sweep the set.  Find the largest intersecting set.  Remove it.  Continue.
    import heapq

    def _ci95(p, max_i):
        """Grow a window around max_i until it holds >= 0.95 of p (or all of it).

        Returns the window bounds as offsets relative to max_i.
        """
        lo = hi = max_i
        total = p[max_i]
        while total < 0.95:
            if lo <= 0 and hi >= len(p) - 1:
                break
            lo = max(0, lo - 1)
            hi = min(len(p) - 1, hi + 1)
            total = sum(p[lo:hi + 1])
        return lo - max_i, hi - max_i

    BP.sort(key=lambda x: x.start_l)

    # Indices of breakpoints not yet assigned to a set. This must be a real
    # list: it shrinks as sets are taken, and range() is immutable on Python 3.
    BP_i = list(range(len(BP)))
    C = []

    while BP_i:
        # Min-heap of (end_l, index). Because BP is sorted by start_l, once the
        # entries ending before BP[i].start_l are popped, every remaining
        # entry contains BP[i].start_l, i.e. they all overlap on the left.
        h_l = []
        max_c = []
        for i in BP_i:
            while h_l and h_l[0][0] < BP[i].start_l:
                heapq.heappop(h_l)
            heapq.heappush(h_l, (BP[i].end_l, i))

            # Repeat the same sweep on the right side of that set.
            h_r = []
            for j in sorted((x[1] for x in h_l), key=lambda x: BP[x].start_r):
                while h_r and h_r[0][0] < BP[j].start_r:
                    heapq.heappop(h_r)
                heapq.heappush(h_r, (BP[j].end_r, j))

                if len(h_r) > len(max_c):
                    # largest set overlapping on both sides seen so far
                    max_c = [y[1] for y in h_r]

        C.append(max_c)
        taken = set(max_c)
        BP_i = [i for i in BP_i if i not in taken]

    for c in C:
        L = [[BP[b_i].start_l, BP[b_i].end_l, BP[b_i].p_l] for b_i in c]
        R = [[BP[b_i].start_r, BP[b_i].end_r, BP[b_i].p_r] for b_i in c]

        [start_R, end_R, a_R] = l_bp.align_intervals(R)
        [start_L, end_L, a_L] = l_bp.align_intervals(L)

        p_L = [0] * len(a_L[0])
        p_R = [0] * len(a_R[0])
        for c_i in range(len(c)):
            for i, v in enumerate(a_L[c_i]):
                p_L[i] += v
            for i, v in enumerate(a_R[c_i]):
                p_R[i] += v

        ALG = 'SUM'
        if use_product:
            pmax_i_L = p_L.index(max(p_L))
            pmax_i_R = p_R.index(max(p_R))

            # A member with zero mass at either summed maximum would zero
            # the product there, so the product is only used when none does.
            if all(a_L[c_i][pmax_i_L] != 0 and a_R[c_i][pmax_i_R] != 0
                   for c_i in range(len(c))):
                ALG = "PROD"
                ls_p_L = [ls.get_ls(1)] * len(a_L[0])
                ls_p_R = [ls.get_ls(1)] * len(a_R[0])
                for c_i in range(len(c)):
                    for i, v in enumerate(a_L[c_i]):
                        ls_p_L[i] = ls.ls_multiply(ls_p_L[i], ls.get_ls(v))
                    for i, v in enumerate(a_R[c_i]):
                        ls_p_R[i] = ls.ls_multiply(ls_p_R[i], ls.get_ls(v))

                ls_sum_L = ls.get_ls(0)
                ls_sum_R = ls.get_ls(0)
                for ls_p in ls_p_L:
                    ls_sum_L = ls.ls_add(ls_sum_L, ls_p)
                for ls_p in ls_p_R:
                    ls_sum_R = ls.ls_add(ls_sum_R, ls_p)

                p_L = [ls.get_p(ls.ls_divide(x, ls_sum_L)) for x in ls_p_L]
                p_R = [ls.get_p(ls.ls_divide(x, ls_sum_R)) for x in ls_p_R]

        # Normalize each side by its own total (the right side was
        # previously divided by the left-side sum).
        sum_L = sum(p_L)
        sum_R = sum(p_R)
        p_L = [x / sum_L for x in p_L]
        p_R = [x / sum_R for x in p_R]

        [clip_start_L, clip_end_L] = l_bp.trim(p_L)
        [clip_start_R, clip_end_R] = l_bp.trim(p_R)

        new_start_L = start_L + clip_start_L
        new_start_R = start_R + clip_start_R

        p_L = p_L[clip_start_L:len(p_L) - clip_end_L]
        p_R = p_R[clip_start_R:len(p_R) - clip_end_R]

        # renormalize after trimming
        s_p_L = sum(p_L)
        s_p_R = sum(p_R)
        p_L = [x / s_p_L for x in p_L]
        p_R = [x / s_p_R for x in p_R]

        max_i_L = p_L.index(max(p_L))
        max_i_R = p_R.index(max(p_R))

        ci95_L = _ci95(p_L, max_i_L)
        ci95_R = _ci95(p_R, max_i_R)
        CIPOS95 = str(ci95_L[0]) + ',' + str(ci95_L[1])
        CIEND95 = str(ci95_R[0]) + ',' + str(ci95_R[1])

        first = BP[c[0]]
        CHROM = first.chr_l
        POS = new_start_L + max_i_L
        END = new_start_R + max_i_R
        v_id += 1
        ID = str(v_id)
        REF = 'N'
        SVTYPE = first.sv_type

        if SVTYPE == 'BND':
            # bracket prefix/suffix of the mate position, keyed by the
            # first two strand characters
            brackets = {
                '++': ('N]', ']'),
                '-+': (']', ']N'),
                '+-': ('N[', '['),
                '--': ('[', '[N'),
            }.get(first.strands[:2])
            ALT = ''
            if brackets is not None:
                ALT = brackets[0] + first.chr_r + ':' + str(END) + brackets[1]
        else:
            ALT = '<' + SVTYPE + '>'

        QUAL = 0.0
        FILTER = '.'
        strand_map = {}
        SU = 0
        PE = 0
        SR = 0
        s_name_list = []

        for b_i in c:
            A = BP[b_i].l.rstrip().split('\t')
            # NOTE(review): only all-digit QUAL strings are summed, so a
            # fractional QUAL such as '12.5' is ignored -- confirm intent.
            if A[5].isdigit():
                QUAL += float(A[5])

            m = l_bp.to_map(A[7])

            for strand_entry in m['STRANDS'].split(','):
                s_type, s_count = strand_entry.split(':')
                strand_map[s_type] = strand_map.get(s_type, 0) + int(s_count)

            SU += int(m['SU'])
            PE += int(m['PE'])
            SR += int(m['SR'])

            s_name_list.append(m['SNAME'] + ':' + A[2])

        SNAME = ','.join(s_name_list)
        STRANDS = ','.join(s + ':' + str(n) for s, n in strand_map.items())

        if first.chr_l != first.chr_r:
            # Don't set SVLEN for an interchromosomal event; it has no meaning.
            SVLEN = None
        elif SVTYPE == 'DEL':
            SVLEN = POS - END
        else:
            SVLEN = END - POS

        cipos = [-max_i_L, len(p_L) - max_i_L - 1]
        ciend = [-max_i_R, len(p_R) - max_i_R - 1]
        CIPOS = ','.join(str(x) for x in cipos)
        CIEND = ','.join(str(x) for x in ciend)
        IMPRECISE = 'IMPRECISE'
        PRPOS = ','.join(str(x) for x in p_L)
        PREND = ','.join(str(x) for x in p_R)

        # Report on stderr when the full interval does not contain the 95%
        # interval.
        if (cipos[0] > ci95_L[0] or cipos[1] < ci95_L[1] or
                ciend[0] > ci95_R[0] or ciend[1] < ci95_R[1]):
            sys.stderr.write(CIPOS + "\t" + str(CIPOS95) + "\n")
            sys.stderr.write(CIEND + "\t" + str(CIEND95) + "\n")

        I = ['SVTYPE=' + str(SVTYPE), 'STRANDS=' + str(STRANDS)]
        # explicit None check so a zero-length SVLEN is still reported
        if SVLEN is not None:
            I.append('SVLEN=' + str(SVLEN))
        I += [
            'CIPOS=' + str(CIPOS), 'CIEND=' + str(CIEND),
            'CIPOS95=' + str(CIPOS95), 'CIEND95=' + str(CIEND95),
            str(IMPRECISE), 'SU=' + str(SU), 'PE=' + str(PE), 'SR=' + str(SR),
            'PRPOS=' + str(PRPOS), 'PREND=' + str(PREND), 'ALG=' + str(ALG),
            'SNAME=' + str(SNAME)
        ]
        if SVTYPE == 'BND':
            I.append('EVENT=' + str(ID))
        else:
            I.append('END=' + str(END))

        O = [CHROM, POS, ID, REF, ALT, str(QUAL), FILTER, ';'.join(I)]
        print_var_line('\t'.join(str(o) for o in O))
    return v_id
Exemplo n.º 3
0
def combine_pdfs(BP, c, use_product, weighting_scheme):
    """Combine the left/right position distributions of a set of breakpoints.

    The left and right intervals of every breakpoint in ``c`` are aligned
    with ``l_bp.align_intervals``. Each breakpoint's distribution is
    optionally weighted. The distributions are then combined by summation
    (``ALG='SUM'``). When ``use_product`` is set and every breakpoint has
    non-zero mass at the summed maximum on both sides, they are instead
    combined by a product computed with the ``ls`` log-space helpers
    (``ALG='PROD'``). The result is normalized, trimmed with ``l_bp.trim``
    and renormalized.

    Args:
        BP: indexable collection of breakpoint objects exposing
            ``left``/``right`` (each with ``start``, ``end`` and ``p``) and a
            tab-separated record line ``l`` whose 8th column is parsed with
            ``l_bp.to_map``.
        c: indices into ``BP`` of the breakpoints to combine.
        use_product: attempt the product combination before keeping the sum.
        weighting_scheme: ``'evidence_wt'`` scales each breakpoint by its
            INFO ``SU`` value; ``'carrier_wt'`` scales it by the number of
            comma-separated ``SNAME`` entries (1 if SNAME is absent). Any
            other value leaves the distributions unweighted.

    Returns:
        ``(new_start_L, new_start_R, p_L, p_R, ALG)``: the trimmed start
        coordinates, the trimmed and normalized distributions, and the
        combination algorithm actually used (``'SUM'`` or ``'PROD'``).
    """
    L = []
    R = []
    for b_i in c:
        b = BP[b_i]
        L.append([b.left.start, b.left.end, b.left.p])
        R.append([b.right.start, b.right.end, b.right.p])

    [start_R, end_R, a_R] = l_bp.align_intervals(R)
    [start_L, end_L, a_L] = l_bp.align_intervals(L)

    p_L = [0] * len(a_L[0])
    p_R = [0] * len(a_R[0])

    for c_i in range(len(c)):
        wt = None
        if weighting_scheme == 'evidence_wt':
            # weight by the breakpoint's INFO SU value
            A = BP[c[c_i]].l.rstrip().split('\t', 10)
            wt = int(l_bp.to_map(A[7])['SU'])
        elif weighting_scheme == 'carrier_wt':
            # weight by the number of comma-separated SNAME entries
            A = BP[c[c_i]].l.rstrip().split('\t', 10)
            m = l_bp.to_map(A[7])
            wt = len(m['SNAME'].split(',')) if 'SNAME' in m else 1
        if wt is not None:
            a_L[c_i] = [wt * ali for ali in a_L[c_i]]
            a_R[c_i] = [wt * ari for ari in a_R[c_i]]

        for i, v in enumerate(a_L[c_i]):
            p_L[i] += v
        for i, v in enumerate(a_R[c_i]):
            p_R[i] += v

    ALG = 'SUM'
    if use_product:
        pmax_i_L = p_L.index(max(p_L))
        pmax_i_R = p_R.index(max(p_R))

        # A member with zero mass at either summed maximum would zero the
        # product there, so the product is only used when none does.
        if all(a_L[c_i][pmax_i_L] != 0 and a_R[c_i][pmax_i_R] != 0
               for c_i in range(len(c))):
            ALG = "PROD"
            ls_p_L = [ls.get_ls(1)] * len(a_L[0])
            ls_p_R = [ls.get_ls(1)] * len(a_R[0])

            for c_i in range(len(c)):
                for i, v in enumerate(a_L[c_i]):
                    ls_p_L[i] = ls.ls_multiply(ls_p_L[i], ls.get_ls(v))
                for i, v in enumerate(a_R[c_i]):
                    ls_p_R[i] = ls.ls_multiply(ls_p_R[i], ls.get_ls(v))

            ls_sum_L = ls.get_ls(0)
            ls_sum_R = ls.get_ls(0)
            for ls_p in ls_p_L:
                ls_sum_L = ls.ls_add(ls_sum_L, ls_p)
            for ls_p in ls_p_R:
                ls_sum_R = ls.ls_add(ls_sum_R, ls_p)

            p_L = [ls.get_p(ls.ls_divide(ls_p, ls_sum_L)) for ls_p in ls_p_L]
            p_R = [ls.get_p(ls.ls_divide(ls_p, ls_sum_R)) for ls_p in ls_p_R]

    # Normalize each side by its own total (the right side was previously
    # divided by the left-side sum, which skewed the trimming below).
    sum_L = sum(p_L)
    sum_R = sum(p_R)
    p_L = [x / sum_L for x in p_L]
    p_R = [x / sum_R for x in p_R]

    [clip_start_L, clip_end_L] = l_bp.trim(p_L)
    [clip_start_R, clip_end_R] = l_bp.trim(p_R)

    new_start_L = start_L + clip_start_L
    new_start_R = start_R + clip_start_R

    p_L = p_L[clip_start_L:len(p_L) - clip_end_L]
    p_R = p_R[clip_start_R:len(p_R) - clip_end_R]

    # renormalize after trimming
    s_p_L = sum(p_L)
    s_p_R = sum(p_R)
    p_L = [x / s_p_L for x in p_L]
    p_R = [x / s_p_R for x in p_R]

    return new_start_L, new_start_R, p_L, p_R, ALG
Exemplo n.º 4
0
def merge(BP, sample_order, v_id, use_product):
    """Merge overlapping breakpoints into records and print them.

    A lone breakpoint is re-emitted with a fresh ID. Its old ID is appended
    to SNAME, its MATEID/EVENT entries are dropped, and EVENT and ALG are
    added.

    Otherwise the breakpoints are repeatedly partitioned by a sweep. The
    sweep finds the largest set whose left intervals share a common point
    and whose right intervals share a common point, removes it, and
    repeats. Each set is then collapsed into one record. Its position
    distributions are the sum of the members' distributions. When
    ``use_product`` is set and the product is usable, they are instead the
    log-space product.

    Args:
        BP: list of breakpoint objects exposing ``start_l``/``end_l``/``p_l``,
            ``start_r``/``end_r``/``p_r``, ``chr_l``/``chr_r``, ``sv_type``,
            ``strands`` and a tab-separated record line ``l``. The list is
            sorted in place by ``start_l``.
        sample_order: not used by this function.
        v_id: the last variant ID issued; every printed record takes the
            next one.
        use_product: attempt the product combination of distributions.

    Returns:
        The last variant ID issued.
    """
    if len(BP) == 1:
        A = BP[0].l.rstrip().split('\t')

        # Tack the original id onto SNAME. This is skipped when SNAME is
        # absent; otherwise find() returns -1 and the id would be glued
        # onto whichever INFO entry happens to be last.
        s_start = A[7].find('SNAME=')
        if s_start > -1:
            s_end = A[7].find(';', s_start)
            if s_end > -1:
                A[7] = A[7][:s_end] + ':' + A[2] + A[7][s_end:]
            else:
                A[7] += ':' + A[2]

        # reset the id to be unique in this file
        v_id += 1
        A[2] = str(v_id)

        # Clip out the old mate and event ids. Splitting on ';' matches whole
        # keys and leaves no dangling ';' when the entry was the last one.
        A[7] = ';'.join(f for f in A[7].split(';')
                        if not f.startswith(('MATEID=', 'EVENT=')))

        # add the new event id and the combination algorithm
        A[7] += ';EVENT=' + A[2]
        A[7] += ';ALG=PROD' if use_product else ';ALG=SUM'

        print_var_line('\t'.join(A))
        return v_id

    # Sweep the set.  Find the largest intersecting set.  Remove it.  Continue.
    import heapq

    def _ci95(p, max_i):
        """Grow a window around max_i until it holds >= 0.95 of p (or all of it).

        Returns the window bounds as offsets relative to max_i.
        """
        lo = hi = max_i
        total = p[max_i]
        while total < 0.95:
            if lo <= 0 and hi >= len(p) - 1:
                break
            lo = max(0, lo - 1)
            hi = min(len(p) - 1, hi + 1)
            total = sum(p[lo:hi + 1])
        return lo - max_i, hi - max_i

    BP.sort(key=lambda x: x.start_l)

    # Indices of breakpoints not yet assigned to a set. This must be a real
    # list: it shrinks as sets are taken, and range() is immutable on Python 3.
    BP_i = list(range(len(BP)))
    C = []

    while BP_i:
        # Min-heap of (end_l, index). Because BP is sorted by start_l, once the
        # entries ending before BP[i].start_l are popped, every remaining
        # entry contains BP[i].start_l, i.e. they all overlap on the left.
        h_l = []
        max_c = []
        for i in BP_i:
            while h_l and h_l[0][0] < BP[i].start_l:
                heapq.heappop(h_l)
            heapq.heappush(h_l, (BP[i].end_l, i))

            # Repeat the same sweep on the right side of that set.
            h_r = []
            for j in sorted((x[1] for x in h_l), key=lambda x: BP[x].start_r):
                while h_r and h_r[0][0] < BP[j].start_r:
                    heapq.heappop(h_r)
                heapq.heappush(h_r, (BP[j].end_r, j))

                if len(h_r) > len(max_c):
                    # largest set overlapping on both sides seen so far
                    max_c = [y[1] for y in h_r]

        C.append(max_c)
        taken = set(max_c)
        BP_i = [i for i in BP_i if i not in taken]

    for c in C:
        L = [[BP[b_i].start_l, BP[b_i].end_l, BP[b_i].p_l] for b_i in c]
        R = [[BP[b_i].start_r, BP[b_i].end_r, BP[b_i].p_r] for b_i in c]

        [start_R, end_R, a_R] = l_bp.align_intervals(R)
        [start_L, end_L, a_L] = l_bp.align_intervals(L)

        p_L = [0] * len(a_L[0])
        p_R = [0] * len(a_R[0])
        for c_i in range(len(c)):
            for i, v in enumerate(a_L[c_i]):
                p_L[i] += v
            for i, v in enumerate(a_R[c_i]):
                p_R[i] += v

        ALG = 'SUM'
        if use_product:
            pmax_i_L = p_L.index(max(p_L))
            pmax_i_R = p_R.index(max(p_R))

            # A member with zero mass at either summed maximum would zero
            # the product there, so the product is only used when none does.
            if all(a_L[c_i][pmax_i_L] != 0 and a_R[c_i][pmax_i_R] != 0
                   for c_i in range(len(c))):
                ALG = "PROD"
                ls_p_L = [get_ls(1)] * len(a_L[0])
                ls_p_R = [get_ls(1)] * len(a_R[0])
                for c_i in range(len(c)):
                    for i, v in enumerate(a_L[c_i]):
                        ls_p_L[i] = ls_multiply(ls_p_L[i], get_ls(v))
                    for i, v in enumerate(a_R[c_i]):
                        ls_p_R[i] = ls_multiply(ls_p_R[i], get_ls(v))

                ls_sum_L = get_ls(0)
                ls_sum_R = get_ls(0)
                for ls_p in ls_p_L:
                    ls_sum_L = ls_add(ls_sum_L, ls_p)
                for ls_p in ls_p_R:
                    ls_sum_R = ls_add(ls_sum_R, ls_p)

                p_L = [get_p(ls_divide(x, ls_sum_L)) for x in ls_p_L]
                p_R = [get_p(ls_divide(x, ls_sum_R)) for x in ls_p_R]

        # Normalize each side by its own total (the right side was
        # previously divided by the left-side sum).
        sum_L = sum(p_L)
        sum_R = sum(p_R)
        p_L = [x / sum_L for x in p_L]
        p_R = [x / sum_R for x in p_R]

        [clip_start_L, clip_end_L] = l_bp.trim(p_L)
        [clip_start_R, clip_end_R] = l_bp.trim(p_R)

        new_start_L = start_L + clip_start_L
        new_start_R = start_R + clip_start_R

        p_L = p_L[clip_start_L:len(p_L) - clip_end_L]
        p_R = p_R[clip_start_R:len(p_R) - clip_end_R]

        # renormalize after trimming
        s_p_L = sum(p_L)
        s_p_R = sum(p_R)
        p_L = [x / s_p_L for x in p_L]
        p_R = [x / s_p_R for x in p_R]

        max_i_L = p_L.index(max(p_L))
        max_i_R = p_R.index(max(p_R))

        ci95_L = _ci95(p_L, max_i_L)
        ci95_R = _ci95(p_R, max_i_R)
        CIPOS95 = str(ci95_L[0]) + ',' + str(ci95_L[1])
        CIEND95 = str(ci95_R[0]) + ',' + str(ci95_R[1])

        first = BP[c[0]]
        CHROM = first.chr_l
        POS = new_start_L + max_i_L
        END = new_start_R + max_i_R
        v_id += 1
        ID = str(v_id)
        REF = 'N'
        SVTYPE = first.sv_type

        if SVTYPE == 'BND':
            # bracket prefix/suffix of the mate position, keyed by the
            # first two strand characters
            brackets = {
                '++': ('N]', ']'),
                '-+': (']', ']N'),
                '+-': ('N[', '['),
                '--': ('[', '[N'),
            }.get(first.strands[:2])
            ALT = ''
            if brackets is not None:
                ALT = brackets[0] + first.chr_r + ':' + str(END) + brackets[1]
        else:
            ALT = '<' + SVTYPE + '>'

        QUAL = 0.0
        FILTER = '.'
        strand_map = {}
        SU = 0
        PE = 0
        SR = 0
        s_name_list = []

        for b_i in c:
            A = BP[b_i].l.rstrip().split('\t')
            # NOTE(review): only all-digit QUAL strings are summed, so a
            # fractional QUAL such as '12.5' is ignored -- confirm intent.
            if A[5].isdigit():
                QUAL += float(A[5])

            m = l_bp.to_map(A[7])

            for strand_entry in m['STRANDS'].split(','):
                s_type, s_count = strand_entry.split(':')
                strand_map[s_type] = strand_map.get(s_type, 0) + int(s_count)

            SU += int(m['SU'])
            PE += int(m['PE'])
            SR += int(m['SR'])

            s_name_list.append(m['SNAME'] + ':' + A[2])

        SNAME = ','.join(s_name_list)
        STRANDS = ','.join(s + ':' + str(n) for s, n in strand_map.items())

        if first.chr_l != first.chr_r:
            # Don't set SVLEN for an interchromosomal event; it has no meaning.
            SVLEN = None
        elif SVTYPE == 'DEL':
            SVLEN = POS - END
        else:
            SVLEN = END - POS

        cipos = [-max_i_L, len(p_L) - max_i_L - 1]
        ciend = [-max_i_R, len(p_R) - max_i_R - 1]
        CIPOS = ','.join(str(x) for x in cipos)
        CIEND = ','.join(str(x) for x in ciend)
        IMPRECISE = 'IMPRECISE'
        PRPOS = ','.join(str(x) for x in p_L)
        PREND = ','.join(str(x) for x in p_R)

        # Report on stderr when the full interval does not contain the 95%
        # interval.
        if (cipos[0] > ci95_L[0] or cipos[1] < ci95_L[1] or
                ciend[0] > ci95_R[0] or ciend[1] < ci95_R[1]):
            sys.stderr.write(CIPOS + "\t" + str(CIPOS95) + "\n")
            sys.stderr.write(CIEND + "\t" + str(CIEND95) + "\n")

        I = ['SVTYPE=' + str(SVTYPE), 'STRANDS=' + str(STRANDS)]
        if SVLEN is not None:
            I.append('SVLEN=' + str(SVLEN))
        I += [
            'CIPOS=' + str(CIPOS), 'CIEND=' + str(CIEND),
            'CIPOS95=' + str(CIPOS95), 'CIEND95=' + str(CIEND95),
            str(IMPRECISE), 'SU=' + str(SU), 'PE=' + str(PE), 'SR=' + str(SR),
            'PRPOS=' + str(PRPOS), 'PREND=' + str(PREND), 'ALG=' + str(ALG),
            'SNAME=' + str(SNAME)
        ]
        if SVTYPE == 'BND':
            I.append('EVENT=' + str(ID))
        else:
            I.append('END=' + str(END))

        O = [CHROM, POS, ID, REF, ALT, str(QUAL), FILTER, ';'.join(I)]
        print_var_line('\t'.join(str(o) for o in O))
    return v_id
Exemplo n.º 5
0
def _merge_remove_info_field(info, prefix):
    """Return VCF INFO string *info* with the first ``prefix...`` entry removed.

    *prefix* includes the ``=`` (e.g. ``'MATEID='``).  If it does not occur,
    *info* is returned unchanged.  When the removed entry is the last one,
    the now-dangling ``;`` separator is dropped too, so the caller can append
    ``';KEY=value'`` without producing ``';;'``.
    """
    start = info.find(prefix)
    if start == -1:
        return info
    end = info.find(';', start)
    if end > -1:
        return info[:start] + info[end + 1:]
    return info[:start].rstrip(';')


def _merge_ci95(p, max_i):
    """Return offsets ``(start, end)`` relative to *max_i* of the window
    around the peak of distribution *p* that first holds >= 95% of the mass.

    The window grows by one position on each side per step (clamped to the
    ends of *p*) until its summed probability reaches 0.95 or it spans all
    of *p*.
    """
    start = end = max_i
    last = len(p) - 1
    total = p[max_i]
    while total < 0.95:
        if start <= 0 and end >= last:
            break
        start = max(0, start - 1)
        end = min(last, end + 1)
        total = sum(p[start:end + 1])
    return start - max_i, end - max_i


def _merge_find_max_cliques(BP):
    """Greedily partition breakpoints into clusters of mutually overlapping ones.

    *BP* must already be sorted by ``start_l``.  Each pass sweeps the
    remaining breakpoints and finds the largest set whose left intervals
    share a common position and whose right intervals share a common
    position.  That set is recorded and removed, and the sweep repeats until
    nothing is left.

    Returns a list of clusters, each a list of indices into *BP*.
    """
    import heapq

    # Must be a real list: range() objects have no remove() on Python 3.
    remaining = list(range(len(BP)))
    cliques = []
    while remaining:
        # Min-heap of (end_l, index).  Everything still on it overlaps the
        # current breakpoint on the left.
        h_l = []
        max_c = []
        for i in remaining:
            # Drop entries that end before the current breakpoint starts.
            while h_l and h_l[0][0] < BP[i].start_l:
                heapq.heappop(h_l)
            heapq.heappush(h_l, (BP[i].end_l, i))

            # Repeat the same sweep on the right ends of the left-overlapping set.
            h_r = []
            on_left = sorted((x[1] for x in h_l), key=lambda x: BP[x].start_r)
            for j in on_left:
                while h_r and h_r[0][0] < BP[j].start_r:
                    heapq.heappop(h_r)
                heapq.heappush(h_r, (BP[j].end_r, j))
                if len(max_c) < len(h_r):
                    # Largest clique so far: record its members.
                    max_c = [y[1] for y in h_r]

        cliques.append(max_c)
        members = set(max_c)
        remaining = [i for i in remaining if i not in members]
    return cliques


def merge(BP, sample_order, v_id, use_product, include_genotypes=False):
    """Merge breakpoints into consensus variants and print one VCF line each.

    Single breakpoint: the record is passed through with these changes.
      * Its SNAME value gets ``':' + original ID`` appended.
      * It receives a fresh ID.
      * Any MATEID/EVENT entries are removed.
      * EVENT is set to the new ID.
      * ALG is set to PROD or SUM according to *use_product*.

    Several breakpoints: they are first grouped into overlapping clusters
    (see ``_merge_find_max_cliques``).  For each cluster:
      * The left and right position distributions are aligned and combined.
      * Each combined distribution is normalized, trimmed with
        ``l_bp.trim`` and renormalized.
      * A merged record is printed.  It carries summed SU/PE/SR/STRANDS
        counts, a summed QUAL, CIPOS/CIEND, and the 95% windows
        CIPOS95/CIEND95.

    Args:
        BP: breakpoint objects; the list is sorted in place by ``start_l``.
        sample_order: sample names giving the genotype column order.
        v_id: last variant id used; incremented once per printed record.
        use_product: when true, the product of the distributions is used.
            This applies only if every breakpoint has non-zero probability
            at the peak of the summed distributions on both sides; otherwise
            the sum is used.  The choice is reported in INFO ALG.
        include_genotypes: also emit FORMAT and per-sample columns.

    Returns:
        The last variant id used.

    Exits the process (``sys.exit(1)``) when genotypes are requested and
    FORMAT strings in a cluster cannot be reconciled.
    """
    if len(BP) == 1:
        A = BP[0].l.rstrip().split('\t')

        # Tag the SNAME value with the record's original ID (s -> s:ID).
        s_start = A[7].find('SNAME=')
        s_end = A[7].find(';', s_start)
        if s_end > -1:
            sname = A[7][s_start + 6:s_end]
            A[7] = A[7][:s_end] + ':' + A[2] + A[7][s_end:]
        else:
            sname = A[7][s_start + 6:]
            A[7] += ':' + A[2]

        # Reset the id to be unique in this output.
        v_id += 1
        A[2] = str(v_id)

        # Drop stale mate/event ids, then point EVENT at the new id.
        A[7] = _merge_remove_info_field(A[7], 'MATEID=')
        A[7] = _merge_remove_info_field(A[7], 'EVENT=')
        A[7] += ';EVENT=' + A[2]
        A[7] += ';ALG=PROD' if use_product else ';ALG=SUM'

        GTS = None
        if include_genotypes:
            null_string = null_format_string(A[8])
            gt_dict = {sname: A[9]}
            GTS = '\t'.join([A[8]] + [gt_dict.get(x, null_string) for x in sample_order])
        print_var_line('\t'.join(A), GTS)
        return v_id

    # Sweep the set: find the largest intersecting set, remove it, continue.
    BP.sort(key=lambda x: x.start_l)

    for c in _merge_find_max_cliques(BP):
        first = BP[c[0]]

        L = [[BP[b_i].start_l, BP[b_i].end_l, BP[b_i].p_l] for b_i in c]
        R = [[BP[b_i].start_r, BP[b_i].end_r, BP[b_i].p_r] for b_i in c]

        start_R, _, a_R = l_bp.align_intervals(R)
        start_L, _, a_L = l_bp.align_intervals(L)

        # Sum the aligned per-breakpoint distributions position by position.
        p_L = [0] * len(a_L[0])
        p_R = [0] * len(a_R[0])
        for c_i in range(len(c)):
            for i, v in enumerate(a_L[c_i]):
                p_L[i] += v
            for i, v in enumerate(a_R[c_i]):
                p_R[i] += v

        ALG = 'SUM'
        if use_product:
            pmax_i_L = p_L.index(max(p_L))
            pmax_i_R = p_R.index(max(p_R))
            # The product is only attempted when every breakpoint has
            # non-zero probability at the peak of the summed distributions.
            miss = sum(1 for c_i in range(len(c))
                       if a_L[c_i][pmax_i_L] == 0 or a_R[c_i][pmax_i_R] == 0)
            if miss == 0:
                ALG = 'PROD'
                # Multiply/normalize through the ls helpers (presumably
                # log-space arithmetic to avoid underflow -- see ls module).
                ls_p_L = [ls.get_ls(1)] * len(a_L[0])
                ls_p_R = [ls.get_ls(1)] * len(a_R[0])
                for c_i in range(len(c)):
                    for i in range(len(a_L[c_i])):
                        ls_p_L[i] = ls.ls_multiply(ls_p_L[i], ls.get_ls(a_L[c_i][i]))
                    for i in range(len(a_R[c_i])):
                        ls_p_R[i] = ls.ls_multiply(ls_p_R[i], ls.get_ls(a_R[c_i][i]))

                ls_sum_L = ls.get_ls(0)
                ls_sum_R = ls.get_ls(0)
                for ls_p in ls_p_L:
                    ls_sum_L = ls.ls_add(ls_sum_L, ls_p)
                for ls_p in ls_p_R:
                    ls_sum_R = ls.ls_add(ls_sum_R, ls_p)

                p_L = [ls.get_p(ls.ls_divide(ls_p, ls_sum_L)) for ls_p in ls_p_L]
                p_R = [ls.get_p(ls.ls_divide(ls_p, ls_sum_R)) for ls_p in ls_p_R]

        # Normalize each side by its own total before trimming.
        sum_L = sum(p_L)
        sum_R = sum(p_R)
        p_L = [x / sum_L for x in p_L]
        p_R = [x / sum_R for x in p_R]

        clip_start_L, clip_end_L = l_bp.trim(p_L)
        clip_start_R, clip_end_R = l_bp.trim(p_R)

        new_start_L = start_L + clip_start_L
        new_start_R = start_R + clip_start_R

        p_L = p_L[clip_start_L:len(p_L) - clip_end_L]
        p_R = p_R[clip_start_R:len(p_R) - clip_end_R]

        # Renormalize after trimming.
        s_p_L = sum(p_L)
        s_p_R = sum(p_R)
        p_L = [x / s_p_L for x in p_L]
        p_R = [x / s_p_R for x in p_R]

        max_i_L = p_L.index(max(p_L))
        max_i_R = p_R.index(max(p_R))

        ci95_L = _merge_ci95(p_L, max_i_L)
        ci95_R = _merge_ci95(p_R, max_i_R)
        CIPOS95 = ','.join(str(x) for x in ci95_L)
        CIEND95 = ','.join(str(x) for x in ci95_R)

        CHROM = first.chr_l
        POS = new_start_L + max_i_L
        END = new_start_R + max_i_R
        v_id += 1
        ID = str(v_id)
        REF = 'N'
        SVTYPE = first.sv_type

        if SVTYPE == 'BND':
            # Breakend ALT notation; bracket direction depends on the strands.
            mate = first.chr_r + ':' + str(END)
            strands = first.strands[:2]
            ALT = ''
            if strands == '++':
                ALT = 'N]' + mate + ']'
            elif strands == '-+':
                ALT = ']' + mate + ']N'
            elif strands == '+-':
                ALT = 'N[' + mate + '['
            elif strands == '--':
                ALT = '[' + mate + '[N'
        else:
            ALT = '<' + SVTYPE + '>'

        QUAL = 0.0
        FILTER = '.'
        strand_map = {}
        SU = PE = SR = 0
        s_name_list = []
        format_string = None
        gt_dict = {}

        for b_i in c:
            A = BP[b_i].l.rstrip().split('\t')
            # Only all-digit QUAL strings are summed; '.' and non-integer
            # values (e.g. '12.5') are skipped by isdigit().
            if A[5].isdigit():
                QUAL += float(A[5])

            m = l_bp.to_map(A[7])

            for strand_entry in m['STRANDS'].split(','):
                s_type, s_count = strand_entry.split(':')
                strand_map[s_type] = strand_map.get(s_type, 0) + int(s_count)

            SU += int(m['SU'])
            PE += int(m['PE'])
            SR += int(m['SR'])

            s_name_list.append(m['SNAME'] + ':' + A[2])

            if include_genotypes:
                if format_string is None:
                    format_string = A[8]

                if format_string == A[8]:
                    gt_dict[m['SNAME']] = A[9]
                else:
                    # Differing FORMAT strings are accepted only when one is a
                    # prefix of the other; the longer one is kept.
                    longer, shorter = A[8], format_string
                    if len(longer) < len(shorter):
                        longer, shorter = shorter, longer

                    if longer.startswith(shorter):
                        format_string = longer
                        gt_dict[m['SNAME']] = A[9]
                    else:
                        sys.stderr.write('Unable to merge and include genotypes when FORMAT fields differ across VCF files\n')
                        sys.stderr.write('Previous: {0} Current: {1}\n'.format(format_string, A[8]))
                        sys.stderr.write('Variant: {0}\n'.format(m['SNAME'] + ':' + A[2]))
                        sys.exit(1)

        SNAME = ','.join(s_name_list)

        GTS = None
        if include_genotypes:
            null_string = null_format_string(format_string)
            GTS = '\t'.join([format_string] + [gt_dict.get(x, null_string) for x in sample_order])

        STRANDS = ','.join(s + ':' + str(n) for s, n in strand_map.items())

        if SVTYPE == 'DEL':
            SVLEN = POS - END
        else:
            SVLEN = END - POS
        # Don't set SVLEN for an interchromosomal event; it makes no sense.
        if first.chr_l != first.chr_r:
            SVLEN = None

        ci_L = (-max_i_L, len(p_L) - max_i_L - 1)
        ci_R = (-max_i_R, len(p_R) - max_i_R - 1)
        CIPOS = ','.join(str(x) for x in ci_L)
        CIEND = ','.join(str(x) for x in ci_R)
        IMPRECISE = 'IMPRECISE'
        PRPOS = ','.join(str(x) for x in p_L)
        PREND = ','.join(str(x) for x in p_R)

        # Sanity check: the full interval should always contain the 95% one.
        if (ci_L[0] > ci95_L[0] or ci_L[1] < ci95_L[1] or
                ci_R[0] > ci95_R[0] or ci_R[1] < ci95_R[1]):
            sys.stderr.write(CIPOS + "\t" + str(CIPOS95) + "\n")
            sys.stderr.write(CIEND + "\t" + str(CIEND95) + "\n")

        info_fields = ['SVTYPE=' + str(SVTYPE),
                       'STRANDS=' + str(STRANDS)]
        if SVLEN is not None:
            info_fields.append('SVLEN=' + str(SVLEN))
        info_fields += ['CIPOS=' + str(CIPOS),
                        'CIEND=' + str(CIEND),
                        'CIPOS95=' + str(CIPOS95),
                        'CIEND95=' + str(CIEND95),
                        str(IMPRECISE),
                        'SU=' + str(SU),
                        'PE=' + str(PE),
                        'SR=' + str(SR),
                        'PRPOS=' + str(PRPOS),
                        'PREND=' + str(PREND),
                        'ALG=' + str(ALG),
                        'SNAME=' + str(SNAME)]

        if SVTYPE == 'BND':
            info_fields.append('EVENT=' + str(ID))
        else:
            info_fields.append('END=' + str(END))

        INFO = ';'.join(info_fields)

        fields = [CHROM, POS, ID, REF, ALT, str(QUAL), FILTER, INFO]
        print_var_line('\t'.join(str(o) for o in fields), GTS)
    return v_id