Ejemplo n.º 1
0
def sample_posterior(model, n, forward_probs=None, verbose=False):
    """
    Sample a state path from the posterior by stochastic traceback.

    model         -- object providing get_num_states(pos),
                     prob_transition(pos1, state1, pos2, state2) and
                     prob_emission(pos, state)
    n             -- number of positions in the path
    forward_probs -- forward table indexed as forward_probs[pos][state];
                     computed with forward_algorithm() when None
    verbose       -- passed through to forward_algorithm()

    Forward, transition and emission terms are combined by addition and
    exponentiated only right before sampling, i.e. they are handled as
    log-probabilities.

    Returns a list of n state indices (an empty list when n <= 0).
    """

    if n <= 0:
        return []

    # get forward probabilities
    if forward_probs is None:
        forward_probs = forward_algorithm(model, n, verbose=verbose)

    def sample_from_logs(logs):
        # The weights passed to stats.sample() here were never normalized, so
        # only their ratios can matter (assumed; confirm against
        # stats.sample).  Subtracting the largest log-weight rescales them
        # all by one constant, which keeps the ratios but pins the largest
        # weight at exp(0) == 1.0.  Without this shift, long paths drive
        # every exp() to 0.0 and sampling degenerates.
        peak = max(logs) if logs else 0.0
        if peak == float("-inf"):
            # every state has zero probability; there is nothing to shift by
            peak = 0.0
        return stats.sample([exp(x - peak) for x in logs])

    # a real list, so positions can be assigned on Python 2 and Python 3
    path = [None] * n

    # base case i=n-1
    B = 0.0
    i = n - 1
    A = [forward_probs[i][j] for j in range(model.get_num_states(i))]
    path[i] = sample_from_logs(A)

    # recurse from position n-2 down to 0
    for i in range(n - 2, -1, -1):
        next_state = path[i + 1]
        C = []
        A = []
        for j in range(model.get_num_states(i)):
            # C[j]: log of (transition j -> next_state) * (emission at i+1)
            C.append(
                model.prob_transition(i, j, i + 1, next_state) +
                model.prob_emission(i + 1, next_state))
            # A[j]: log of forward(i, j) * C[j] * B, with B the accumulated
            # log-probability of the already-sampled suffix
            A.append(forward_probs[i][j] + C[j] + B)
        path[i] = j = sample_from_logs(A)
        # extend the suffix term with the step that was just chosen
        B += C[j]

    return path
def sample_posterior(model, n, forward_probs=None, verbose=False):
    """
    Sample a state path from the posterior by stochastic traceback.

    model         -- object providing get_num_states(pos),
                     prob_transition(pos1, state1, pos2, state2) and
                     prob_emission(pos, state)
    n             -- number of positions in the path
    forward_probs -- forward table indexed as forward_probs[pos][state];
                     computed with forward_algorithm() when None
    verbose       -- passed through to forward_algorithm()

    Forward, transition and emission terms are combined by addition and
    exponentiated only right before sampling, i.e. they are handled as
    log-probabilities.

    Returns a list of n state indices (an empty list when n <= 0).
    """

    if n <= 0:
        return []

    # get forward probabilities
    if forward_probs is None:
        forward_probs = forward_algorithm(model, n, verbose=verbose)

    def sample_from_logs(logs):
        # The weights passed to stats.sample() here were never normalized, so
        # only their ratios can matter (assumed; confirm against
        # stats.sample).  Subtracting the largest log-weight rescales them
        # all by one constant, which keeps the ratios but pins the largest
        # weight at exp(0) == 1.0.  Without this shift, long paths drive
        # every exp() to 0.0 and sampling degenerates.
        peak = max(logs) if logs else 0.0
        if peak == float("-inf"):
            # every state has zero probability; there is nothing to shift by
            peak = 0.0
        return stats.sample([exp(x - peak) for x in logs])

    # a real list, so positions can be assigned on Python 2 and Python 3
    path = [None] * n

    # base case i=n-1
    B = 0.0
    i = n-1
    A = [forward_probs[i][j] for j in range(model.get_num_states(i))]
    path[i] = sample_from_logs(A)

    # recurse from position n-2 down to 0
    for i in range(n-2, -1, -1):
        next_state = path[i+1]
        C = []
        A = []
        for j in range(model.get_num_states(i)):
            # C[j]: log of (transition j -> next_state) * (emission at i+1)
            C.append(
                model.prob_transition(i, j, i+1, next_state) +
                model.prob_emission(i+1, next_state))
            # A[j]: log of forward(i, j) * C[j] * B, with B the accumulated
            # log-probability of the already-sampled suffix
            A.append(forward_probs[i][j] + C[j] + B)
        path[i] = j = sample_from_logs(A)
        # extend the suffix term with the step that was just chosen
        B += C[j]

    return path
Ejemplo n.º 3
0
def proposeTreeWeighted(tree):
    """
    Propose a tree by applying one NNI move to a copy of `tree`.

    Every internal, non-root node of the copy contributes the edge to its
    parent as a candidate.  A candidate is drawn with stats.sample() using
    the node's data["error"] value as weight, so those nodes must carry an
    "error" entry in their data dict.  When the weights sum to zero, equal
    weights are used instead.

    Returns the copy (obtained from tree.copy()) after proposeNni() has
    been applied to it.
    """

    proposal = tree.copy()

    # candidate NNI edges: (node, parent) for each internal non-root node
    candidates = [
        (node, node.parent)
        for node in proposal.nodes.values()
        if not node.isLeaf() and node != proposal.root
    ]

    # one weight per candidate edge
    weights = [child.data["error"] for child, _ in candidates]

    # fall back to equal weights when there is no error signal at all
    if sum(weights) == 0:
        count = float(len(weights))
        weights = [1. / count for _ in weights]

    # sample by weight
    child, parent = candidates[stats.sample(weights)]

    # the last argument is a coin flip between 0 and 1
    coin = int(round(random.random()))
    proposeNni(proposal, child, parent, coin)
    return proposal
Ejemplo n.º 4
0
def proposeTree2(conf, tree,  distmat, labels,
                  stree, gene2species, params, visited):
    """
    Propose a tree by chaining one to three proposeTree() moves.

    The number of moves is 1, 2 or 3, chosen with stats.sample() using the
    weights 0.70, 0.20 and 0.10.  Afterwards, with probability
    conf["rerootprob"], phylo.reconRoot() is called on the result with
    newCopy=False (its return value is ignored).

    conf         -- configuration dict; "rerootprob" is read here and conf
                    is passed on to proposeTree()
    tree         -- starting tree
    stree        -- passed to phylo.reconRoot()
    gene2species -- passed to phylo.reconRoot()

    distmat, labels, params and visited are accepted but not used by this
    function.

    Returns the proposed tree.
    """
    nniprobs = [.70, .20, .10]

    # stats.sample() returns an index 0..2; apply one more move than that
    nmoves = stats.sample(nniprobs) + 1

    tree2 = tree
    # range() rather than xrange() so this runs on Python 2 and Python 3
    for _ in range(nmoves):
        tree2 = proposeTree(conf, tree2)

    if random.random() < conf["rerootprob"]:
        phylo.reconRoot(tree2, stree, gene2species, newCopy=False)
    return tree2
Ejemplo n.º 5
0
def propose_daughters(coal_tree, coal_recon, locus_tree, locus_events):
    """
    Pick at most one daughter child for every duplication node.

    coal_tree, coal_recon, locus_tree -- passed unchanged to
                    coal.count_lineages_per_branch()
    locus_events -- mapping from locus-tree node to its event label; only
                    nodes labelled "dup" are considered

    For each "dup" node, the children whose lineages[child][1] equals 1 are
    eligible (presumably one of the per-branch lineage counts; confirm
    against coal.count_lineages_per_branch).  One eligible child is drawn
    with equal weights via stats.sample(); nodes without an eligible child
    contribute nothing.

    Returns the set of chosen daughter nodes.
    """

    lineages = coal.count_lineages_per_branch(coal_tree, coal_recon, locus_tree)
    daughters = set()

    # items() rather than iteritems() so this runs on Python 2 and Python 3
    for node, event in locus_events.items():
        if event == "dup":
            # choose one of the children of node to be a daughter
            eligible = [child for child in node.children
                        if lineages[child][1] == 1]
            if eligible:
                daughters.add(eligible[stats.sample([1] * len(eligible))])

    return daughters
Ejemplo n.º 6
0
def propose_daughters(coal_tree, coal_recon, locus_tree, locus_events):
    """
    Pick at most one daughter child for every duplication node.

    coal_tree, coal_recon, locus_tree -- passed unchanged to
                    coal.count_lineages_per_branch()
    locus_events -- mapping from locus-tree node to its event label; only
                    nodes labelled "dup" are considered

    For each "dup" node, the children whose lineages[child][1] equals 1 are
    eligible (presumably one of the per-branch lineage counts; confirm
    against coal.count_lineages_per_branch).  One eligible child is drawn
    with equal weights via stats.sample(); nodes without an eligible child
    contribute nothing.

    Returns the set of chosen daughter nodes.
    """

    lineages = coal.count_lineages_per_branch(coal_tree, coal_recon,
                                              locus_tree)
    daughters = set()

    # items() rather than iteritems() so this runs on Python 2 and Python 3
    for node, event in locus_events.items():
        if event == "dup":
            # choose one of the children of node to be a daughter
            eligible = [
                child for child in node.children if lineages[child][1] == 1
            ]
            if eligible:
                daughters.add(eligible[stats.sample([1] * len(eligible))])

    return daughters
Ejemplo n.º 7
0
def sample_dsmc_sprs(k,
                     popsize,
                     rho,
                     recombmap=None,
                     start=0.0,
                     end=0.0,
                     times=None,
                     times2=None,
                     init_tree=None,
                     names=None,
                     make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It first yields the initial local tree and then,
    for every sampled recombination, a tuple
        (pos, (rleaves, recomb_time), (cleaves, coal_time))
    where rleaves / cleaves are the leaf names below the recombining node
    and below the node it re-coalesces with.

    k          -- chromosomes
    popsize    -- effective population size (haploid); a single value, or
                  a sequence indexed by time interval
    rho        -- recombination rate (recombinations / site / generation)
    recombmap  -- map for variable recombination rate
    start      -- starting chromosome coordinate
    end        -- ending chromosome coordinate
    times      -- discretized time points (required)
    times2     -- time points passed to argweaver.discretize_arg() and read
                  at indices 2*j-1, 2*j, 2*j+1 for interval lengths
                  (required)
    init_tree  -- initial local tree; sampled with sample_tree() when None
    names      -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)
    """

    # both time grids are mandatory even though they default to None
    assert times is not None
    assert times2 is not None
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i - 1] for i in range(1, ntimes + 1)]
    #    times2 = get_coal_times(times)

    # accept either one population size or one per time interval
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)

    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k,
                                popsizes,
                                times,
                                start=start,
                                end=end,
                                names=names,
                                make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    # work on a copy so the tree already yielded is not modified below
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(
            sample_next_recomb(treelen,
                               rho,
                               pos=pos,
                               recombmap=recombmap,
                               minlen=1))
        pos += blocklen
        # stop once the next breakpoint reaches the end of the region
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, nrecombs, ncoals = argweaver.get_nlineages_recomb_coal(
            tree, times)
        # weight each interval up to the root by (branch count * length)
        probs = [
            nbranches[i] * time_steps[i] for i in range(root_age_index + 1)
        ]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination
        # (uniformly among non-root states at the chosen time index)
        branches = [
            x for x in states
            if x[1] == recomb_time_index and x[0] != tree.root.name
        ]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time
        # walk upward one time index at a time, stopping at the first
        # index where a coalescence is drawn
        j = recomb_time_index
        last_kj = nbranches[max(j - 1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # do not count the recombining branch itself
            if ((recomb_node.name, j) in states
                    and recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            A = (times2[2 * j + 1] - times2[2 * j]) * kj
            if j > recomb_time_index:
                A += (times2[2 * j] - times2[2 * j - 1]) * last_kj
            coal_prob = 1.0 - exp(-A / float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # collect recomb_node and any descendants reached through nodes
            # whose age equals coal_time
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [
            x for x in states if x[1] == coal_time_index
            and x[0] not in exclude and x != exclude2
        ]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            # coal_node had no parent, so recoal sits above it
            coal_node.parents.append(recoal)

        # remove broken node
        # (its remaining child is attached to broken's parent, if any)
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
Ejemplo n.º 8
0
def proposeTree3(conf, tree,  distmat, labels,
                  stree, gene2species, params, visited):
    """
    Propose a tree by pruning one random subtree and trying several regrafts.

    A non-root node is picked with equal weights via stats.sample() and the
    tree is split at the edge to its parent with splitTree().  The pruned
    part is then regrafted at up to conf["regraftloop"] randomly ordered
    nodes of the remaining tree, and the best-scoring result is kept.

    conf         -- configuration dict; "regraftloop" is read here
    tree         -- starting tree; tree.data["logl"] is the score to beat
    distmat, labels -- passed to Spidir.setTreeDistances()
    stree, gene2species, params -- passed to Spidir.treeLogLikelihood()
    visited      -- mapping from tree hash to (logl, tree, count); trees
                    not yet in it are scored, and addVisited() is called
                    for every candidate

    Returns the highest-logl tree found among the candidates, or a copy of
    the input tree when no candidate beats tree.data["logl"].
    """
    toplogl = tree.data["logl"]
    toptree = tree.copy()

    tree = tree.copy()

    # list() so that remove() works on Python 3 dict views as well
    nodes = list(tree.nodes.values())
    nodes.remove(tree.root)
    weights = [1 for x in nodes]
    badgene = nodes[stats.sample(weights)]

    # split the tree at the edge above the chosen node
    tree1, tree2 = splitTree(tree, badgene, badgene.parent)

    # candidate regraft points: every non-root node of tree1, random order
    names = list(tree1.nodes.keys())
    names.remove(tree1.root.name)
    random.shuffle(names)

    for name in names[:min(len(names), conf["regraftloop"])]:
        tree = tree1.copy()
        node = tree.nodes[name]

        regraftTree(tree, tree2.copy(), node, node.parent)

        thash = phylo.hash_tree(tree)

        if thash not in visited:
            Spidir.setTreeDistances(conf, tree, distmat, labels)
            # return value not needed: the score is read back from
            # visited[thash] below (presumably addVisited() picks it up
            # from the tree; confirm)
            Spidir.treeLogLikelihood(conf, tree, stree, gene2species, params)
        addVisited(conf, visited, tree, gene2species, thash)
        logl, tree, count = visited[thash]

        if logl > toplogl:
            toplogl = logl
            toptree = tree

    assert toptree is not None

    return toptree
Ejemplo n.º 9
0
def sample_dsmc_sprs(
        k, popsize, rho, recombmap=None, start=0.0, end=0.0, times=None,
        times2=None, init_tree=None, names=None, make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It first yields the initial local tree and then,
    for every sampled recombination, a tuple
        (pos, (rleaves, recomb_time), (cleaves, coal_time))
    where rleaves / cleaves are the leaf names below the recombining node
    and below the node it re-coalesces with.

    k          -- chromosomes
    popsize    -- effective population size (haploid); a single value, or
                  a sequence indexed by time interval
    rho        -- recombination rate (recombinations / site / generation)
    recombmap  -- map for variable recombination rate
    start      -- starting chromosome coordinate
    end        -- ending chromosome coordinate
    times      -- discretized time points (required)
    times2     -- time points passed to argweaver.discretize_arg() and read
                  at indices 2*j-1, 2*j, 2*j+1 for interval lengths
                  (required)
    init_tree  -- initial local tree; sampled with sample_tree() when None
    names      -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)
    """

    # both time grids are mandatory even though they default to None
    assert times is not None
    assert times2 is not None
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i-1] for i in range(1, ntimes+1)]
#    times2 = get_coal_times(times)

    # accept either one population size or one per time interval
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)


    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k, popsizes, times, start=start, end=end,
                                names=names, make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    # work on a copy so the tree already yielded is not modified below
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(sample_next_recomb(treelen, rho, pos=pos,
                                          recombmap=recombmap, minlen=1))
        pos += blocklen
        # stop once the next breakpoint reaches the end of the region
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, nrecombs, ncoals = argweaver.get_nlineages_recomb_coal(
            tree, times)
        # weight each interval up to the root by (branch count * length)
        probs = [nbranches[i] * time_steps[i]
                 for i in range(root_age_index+1)]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination
        # (uniformly among non-root states at the chosen time index)
        branches = [x for x in states if x[1] == recomb_time_index and
                    x[0] != tree.root.name]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time
        # walk upward one time index at a time, stopping at the first
        # index where a coalescence is drawn
        j = recomb_time_index
        last_kj = nbranches[max(j-1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # do not count the recombining branch itself
            if ((recomb_node.name, j) in states and
                    recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            A = (times2[2*j+1] - times2[2*j]) * kj
            if j > recomb_time_index:
                A += (times2[2*j] - times2[2*j-1]) * last_kj
            coal_prob = 1.0 - exp(-A/float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # collect recomb_node and any descendants reached through nodes
            # whose age equals coal_time
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [x for x in states if x[1] == coal_time_index and
                    x[0] not in exclude and x != exclude2]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            # coal_node had no parent, so recoal sits above it
            coal_node.parents.append(recoal)

        # remove broken node
        # (its remaining child is attached to broken's parent, if any)
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()