Exemplo n.º 1
0
def sample_dsmc_sprs(k,
                     popsize,
                     rho,
                     recombmap=None,
                     start=0.0,
                     end=0.0,
                     times=None,
                     times2=None,
                     init_tree=None,
                     names=None,
                     make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It first yields the initial local tree, then one
    tuple per sampled recombination:

        (pos, (rleaves, recomb_time), (cleaves, coal_time))

    where rleaves / cleaves are the leaf names below the recombining and
    the recoalescing branch.  Sampling stops once the next recombination
    position reaches end - 1.

    k          -- chromosomes
    popsize    -- effective population size (haploid); a single value or a
                  sequence with one value per time interval
    rho        -- recombination rate (recombinations / site / generation)
    recombmap  -- map for variable recombination rate
    start      -- starting chromosome coordinate
    end        -- ending chromosome coordinate
    times      -- discretized time points (required)
    times2     -- interleaved time points, indexed at 2*j-1, 2*j, 2*j+1 for
                  interval j (required; presumably get_coal_times(times))
    init_tree  -- initial local tree (default: sampled with sample_tree and
                  discretized on times2)
    names      -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)

    Raises ValueError (when the generator is first advanced) if times or
    times2 is missing.
    """

    # Explicit checks instead of assert: asserts vanish under `python -O`.
    if times is None:
        raise ValueError("times must be given")
    if times2 is None:
        raise ValueError("times2 must be given")
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i - 1] for i in range(1, ntimes + 1)]

    # accept either a constant or a per-interval population size
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)

    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k,
                                popsizes,
                                times,
                                start=start,
                                end=end,
                                names=names,
                                make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(
            sample_next_recomb(treelen,
                               rho,
                               pos=pos,
                               recombmap=recombmap,
                               minlen=1))
        pos += blocklen
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination, weighted by the number of
        # branches in each interval times the interval's length
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, _, _ = argweaver.get_nlineages_recomb_coal(tree, times)
        probs = [
            nbranches[i] * time_steps[i] for i in range(root_age_index + 1)
        ]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination: any non-root branch present at
        # the chosen time index
        branches = [
            x for x in states
            if x[1] == recomb_time_index and x[0] != tree.root.name
        ]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time: walk up from the recombination interval and
        # coalesce in interval j with probability 1 - exp(-A / popsize);
        # if no interval fires, coalescence lands at times[ntimes - 1]
        j = recomb_time_index
        last_kj = nbranches[max(j - 1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # don't count the recombining branch itself while it still
            # spans time j
            if ((recomb_node.name, j) in states
                    and recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            A = (times2[2 * j + 1] - times2[2 * j]) * kj
            if j > recomb_time_index:
                A += (times2[2 * j] - times2[2 * j - 1]) * last_kj
            coal_prob = 1.0 - exp(-A / float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # exclude node and any descendants sitting exactly at coal_time
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [
            x for x in states if x[1] == coal_time_index
            and x[0] not in exclude and x != exclude2
        ]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree: a new node `recoal` at coal_time joins
        # recomb_node and coal_node
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree, splicing it between coal_node and its
        # former parent (if any)
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            coal_node.parents.append(recoal)

        # remove broken node: recomb_node's old parent now has a single
        # child, which takes its place (or becomes the root)
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
Exemplo n.º 2
0
def sample_dsmc_sprs(
        k, popsize, rho, recombmap=None, start=0.0, end=0.0, times=None,
        times2=None, init_tree=None, names=None, make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It first yields the initial local tree, then one
    tuple per sampled recombination:

        (pos, (rleaves, recomb_time), (cleaves, coal_time))

    where rleaves / cleaves are the leaf names below the recombining and
    the recoalescing branch.  Sampling stops once the next recombination
    position reaches end - 1.

    k          -- chromosomes
    popsize    -- effective population size (haploid); a single value or a
                  sequence with one value per time interval
    rho        -- recombination rate (recombinations / site / generation)
    recombmap  -- map for variable recombination rate
    start      -- starting chromosome coordinate
    end        -- ending chromosome coordinate
    times      -- discretized time points (required)
    times2     -- interleaved time points, indexed at 2*j-1, 2*j, 2*j+1 for
                  interval j (required; presumably get_coal_times(times))
    init_tree  -- initial local tree (default: sampled with sample_tree and
                  discretized on times2)
    names      -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)

    Raises ValueError (when the generator is first advanced) if times or
    times2 is missing.
    """

    # Explicit checks instead of assert: asserts vanish under `python -O`.
    if times is None:
        raise ValueError("times must be given")
    if times2 is None:
        raise ValueError("times2 must be given")
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i-1] for i in range(1, ntimes+1)]

    # accept either a constant or a per-interval population size
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)

    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k, popsizes, times, start=start, end=end,
                                names=names, make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(sample_next_recomb(treelen, rho, pos=pos,
                                          recombmap=recombmap, minlen=1))
        pos += blocklen
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination, weighted by the number of
        # branches in each interval times the interval's length
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, _, _ = argweaver.get_nlineages_recomb_coal(tree, times)
        probs = [nbranches[i] * time_steps[i]
                 for i in range(root_age_index+1)]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination: any non-root branch present at
        # the chosen time index
        branches = [x for x in states if x[1] == recomb_time_index and
                    x[0] != tree.root.name]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time: walk up from the recombination interval and
        # coalesce in interval j with probability 1 - exp(-A / popsize);
        # if no interval fires, coalescence lands at times[ntimes - 1]
        j = recomb_time_index
        last_kj = nbranches[max(j-1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # don't count the recombining branch itself while it still
            # spans time j
            if ((recomb_node.name, j) in states and
                    recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            A = (times2[2*j+1] - times2[2*j]) * kj
            if j > recomb_time_index:
                A += (times2[2*j] - times2[2*j-1]) * last_kj
            coal_prob = 1.0 - exp(-A/float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # exclude node and any descendants sitting exactly at coal_time
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [x for x in states if x[1] == coal_time_index and
                    x[0] not in exclude and x != exclude2]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree: a new node `recoal` at coal_time joins
        # recomb_node and coal_node
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree, splicing it between coal_node and its
        # former parent (if any)
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            coal_node.parents.append(recoal)

        # remove broken node: recomb_node's old parent now has a single
        # child, which takes its place (or becomes the root)
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
Exemplo n.º 3
0
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix
    """

    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i - 1] for i in range(1, len(times))]
    time_steps.append(200000 * 10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2 * n, rho, start=0, end=length)

    argweaver.discretize_arg(arg, times)
    print "recomb", arglib.get_recomb_pos(arg)

    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(tree, states, nlineages, times,
                                          time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        return 1.0 - exp(-time_steps[j] / (2.0 * n))

    def recoal2(k, j):
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        if j == nstates - 1:
            return exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j] / (2.0 * n))) *
                    exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2 * times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        p = sum(recoal(k, b) * recomb(a, k) for k in range(0, min(a, b) + 1))
        p += sum(recoal(k, b) * recomb(a, k) for k in range(0, min(a, b) + 1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            print isrecomb(states[i][1])
            print states[i], states[j], mat[i][j], log(trans(i, j))
            fequal(mat[i][j], log(trans(i, j)))

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i + 1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(
            sum(
                sum(
                    recoal(k, j) * recomb(i, k)
                    for k in range(0,
                                   min(i, j) + 1))
                for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
Exemplo n.º 4
0
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix
    """

    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i-1]
                  for i in range(1, len(times))]
    time_steps.append(200000*10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)

    argweaver.discretize_arg(arg, times)
    print "recomb", arglib.get_recombs(arg)

    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(
        tree, states, nlineages,
        times, time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        return 1.0 - exp(-time_steps[j]/(2.0 * n))

    def recoal2(k, j):
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        if j == nstates-1:
            return exp(- sum(time_steps[m] / (2.0 * n)
                             for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j]/(2.0 * n))) *
                    exp(- sum(time_steps[m] / (2.0 * n)
                              for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2*times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        p = sum(recoal(k, b) * recomb(a, k)
                for k in range(0, min(a, b)+1))
        p += sum(recoal(k, b) * recomb(a, k)
                 for k in range(0, min(a, b)+1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            print isrecomb(states[i][1])
            print states[i], states[j], mat[i][j], log(trans(i, j))
            fequal(mat[i][j], log(trans(i, j)))

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i+1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(sum(sum(recoal(k, j) * recomb(i, k)
                       for k in range(0, min(i, j)+1))
                   for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)