Esempio n. 1
0
def get_split_sets(splits, tracks, breaks):
    """
    Group compatible splits into non-recombining blocks.

    The positions in `breaks` divide the sequence into len(breaks) + 1
    blocks.  Each split is appended to every block spanned by its track
    (track[0] .. track[1]).  Block boundaries are located with
    util.binsearch, so `breaks` is presumably sorted ascending.  A repeated
    (split, track[0], track[1]) combination is only processed once.

    Returns a list holding one list of splits per block.
    """
    blocks = [[] for _ in breaks] + [[]]
    last_block = len(blocks) - 1

    processed = set()
    for split, track in zip(splits, tracks):
        key = (split, track[0], track[1])
        if key in processed:
            continue
        processed.add(key)

        # Block after the lower break index binsearch reports for the
        # track start; block 0 when there is no such break.
        lower = util.binsearch(breaks, track[0])[0]
        first = 0 if lower is None else lower + 1

        # Block before the upper break index binsearch reports for the
        # track end; the final block when there is no such break.
        upper = util.binsearch(breaks, track[1])[1]
        last = last_block if upper is None else upper

        for idx in range(first, last + 1):
            blocks[idx].append(split)

    return blocks
Esempio n. 2
0
def get_split_sets(splits, tracks, breaks):
    """
    Group compatible splits into non-recombining blocks.

    `breaks` splits the sequence into len(breaks) + 1 blocks.  Every split
    is recorded in each block its track (track[0] .. track[1]) overlaps, as
    located by util.binsearch over `breaks` (presumably sorted ascending).
    Duplicate (split, track[0], track[1]) combinations are skipped.

    Returns one list of splits per block.
    """
    blocks = [[] for _ in breaks]
    blocks.append([])
    final_index = len(blocks) - 1

    already = set()
    for split, track in izip(splits, tracks):
        signature = (split, track[0], track[1])
        if signature in already:
            continue
        already.add(signature)

        # start block: one past the lower break index for the track start
        low_break = util.binsearch(breaks, track[0])[0]
        if low_break is not None:
            begin = low_break + 1
        else:
            begin = 0

        # end block: the upper break index for the track end
        high_break = util.binsearch(breaks, track[1])[1]
        if high_break is not None:
            finish = high_break
        else:
            finish = final_index

        for block_idx in xrange(begin, finish + 1):
            blocks[block_idx].append(split)

    return blocks
Esempio n. 3
0
def query_point_regions(point, regions, inc=True):
    """
    Yield every region in `regions` that contains `point`.

    Each region is indexed as region[0] (start) and region[1] (end).  When
    `inc` is true the endpoints count as inside (start <= point <= end);
    otherwise only strict containment (start < point < end) matches.
    Regions are yielded in their order within `regions`.  The input does
    not need to be sorted.
    """
    # A single linear pass.  Pruning with binary searches over a start-sorted
    # list and an end-sorted copy is unsound: an index found in the
    # end-sorted copy does not address the same region in the start-sorted
    # list, so overlapping regions get skipped.  Sorting a copy on every call
    # would also cost more (O(n log n)) than this O(n) scan.
    for region in regions:
        if inc:
            hit = region[0] <= point <= region[1]
        else:
            hit = region[0] < point < region[1]
        if hit:
            yield region
def query_point_regions(point, regions, inc=True):
    """
    Yield every region in `regions` that contains `point`.

    Each region is indexed as region[0] (start) and region[1] (end).  When
    `inc` is true the endpoints count as inside (start <= point <= end);
    otherwise only strict containment (start < point < end) matches.
    Regions are yielded in their order within `regions`.  The input does
    not need to be sorted.
    """
    # A single linear pass over `point` (not an undefined name).  Pruning
    # with binary searches over a start-sorted list and an end-sorted copy is
    # unsound: an index found in the end-sorted copy does not address the
    # same region in the start-sorted list, so overlapping regions get
    # skipped.  Sorting a copy on every call would also cost more than this
    # O(n) scan.
    for region in regions:
        if inc:
            hit = region[0] <= point <= region[1]
        else:
            hit = region[0] < point < region[1]
        if hit:
            yield region
def find_region(regions, region):
    """Return the index of `region` in `regions`, or None if it is absent.

    `regions` must be sorted by `.start` (it is binary-searched on that
    attribute).  The search starts at the upper index util.binsearch reports
    for ``region.start - 1`` and scans forward until an element equal to
    `region` is found.  Returns None when binsearch reports no upper index
    or the scan runs off the end of `regions`.
    """
    # Explicit three-way comparison, equivalent to Python 2's `cmp`, which
    # no longer exists in Python 3.
    _, ind = util.binsearch(regions, region.start - 1,
                            lambda a, b: (a.start > b) - (a.start < b))
    if ind is None:
        return None

    # Several regions may share a start; walk forward to the equal one.
    while ind < len(regions) and regions[ind] != region:
        ind += 1

    if ind == len(regions):
        return None
    return ind
Esempio n. 6
0
def find_region(regions, region):
    """Return the index of `region` in `regions`, or None if it is absent.

    `regions` must be sorted by `.start` (it is binary-searched on that
    attribute).  The search starts at the upper index util.binsearch reports
    for ``region.start - 1`` and scans forward until an element equal to
    `region` is found.  Returns None when binsearch reports no upper index
    or the scan runs off the end of `regions`.
    """
    # Explicit three-way comparison, equivalent to Python 2's `cmp`, which
    # no longer exists in Python 3.
    _, ind = util.binsearch(regions, region.start - 1,
                            lambda a, b: (a.start > b) - (a.start < b))
    if ind is None:
        return None

    # Several regions may share a start; walk forward to the equal one.
    while ind < len(regions) and regions[ind] != region:
        ind += 1

    if ind == len(regions):
        return None
    return ind
Esempio n. 7
0
def get_mutation_split_tracks(arg, mut_splits, mut_pos):
    """
    Compute, for each mutation, the track of its split in mutation-index
    coordinates.

    A split is the sorted tuple of leaf names below a node with exactly two
    distinct children in a marginal tree.  Walking the recombination blocks
    of `arg` alongside their marginal trees, each occurrence of a split of
    interest is recorded as a (start, end) track.  A block whose start
    equals the end of the split's previous track extends that track.

    For mutation i (split mut_splits[i] at position mut_pos[i]), the first
    track strictly containing mut_pos[i] is converted to a pair (a, b):
      - a is the lower index util.binsearch reports for the track start in
        `mut_pos`, or -0.5 when it reports None;
      - b is the upper index reported for the track end, or
        len(mut_pos) - 0.5 when None.
    `mut_pos` is binary-searched, so it is presumably sorted.

    Returns one (a, b) pair per mutation, in mutation order.  Fails an
    assertion (message: the mutation index) if no track of the mutation's
    split contains its position.
    """
    wanted_splits = set(mut_splits)
    split_tracks = defaultdict(lambda: [])

    paired = zip(arglib.iter_recomb_blocks(arg),
                 arglib.iter_marginal_trees(arg))
    for block, tree in paired:
        for node in tree:
            # skip nodes without exactly two distinct children
            if len(node.children) != 2 or node.children[0] == node.children[1]:
                continue
            split = tuple(sorted(tree.leaf_names(node)))
            if split not in wanted_splits:
                continue
            track_list = split_tracks[split]
            if track_list and track_list[-1][1] == block[0]:
                # adjacent block: stretch the previous track's end
                track_list[-1] = (track_list[-1][0], block[1])
            else:
                track_list.append(block)

    mut_tracks = []
    for idx, pos in enumerate(mut_pos):
        for track in split_tracks[mut_splits[idx]]:
            if not track[0] < pos < track[1]:
                continue
            left = util.binsearch(mut_pos, track[0])[0]
            if left is None:
                left = -.5
            right = util.binsearch(mut_pos, track[1])[1]
            if right is None:
                right = len(mut_pos) - .5
            mut_tracks.append((left, right))
            break
        else:
            assert False, idx

    return mut_tracks
Esempio n. 8
0
def get_mutation_split_tracks(arg, mut_splits, mut_pos):
    """
    Compute, for each mutation, the track of its split in mutation-index
    coordinates.

    A split is the sorted tuple of leaf names below a node with exactly two
    distinct children in a marginal tree.  Recombination blocks of `arg`
    are walked together with their marginal trees.  Each block in which a
    split of interest occurs either extends that split's previous track
    (when the block starts where the track ends) or starts a new track.

    For mutation i, the first track of split mut_splits[i] strictly
    containing mut_pos[i] becomes (a, b): a is the lower index
    util.binsearch reports for the track start in `mut_pos` (-0.5 if None),
    b the upper index for the track end (len(mut_pos) - 0.5 if None).
    `mut_pos` is binary-searched, so it is presumably sorted.

    Returns one (a, b) pair per mutation, in mutation order.  Fails an
    assertion (message: the mutation index) when no track contains the
    mutation.
    """
    target_splits = set(mut_splits)
    tracks_of = defaultdict(lambda: [])

    walker = izip(arglib.iter_recomb_blocks(arg),
                  arglib.iter_marginal_trees(arg))
    for block, tree in walker:
        for node in tree:
            has_two_distinct = (len(node.children) == 2 and
                                not node.children[0] == node.children[1])
            if not has_two_distinct:
                continue
            split = tuple(sorted(tree.leaf_names(node)))
            if split in target_splits:
                spans = tracks_of[split]
                if len(spans) > 0 and spans[-1][1] == block[0]:
                    # contiguous with the last span: extend it
                    spans[-1] = (spans[-1][0], block[1])
                else:
                    spans.append(block)

    result = []
    for m in xrange(len(mut_pos)):
        pos = mut_pos[m]
        for span in tracks_of[mut_splits[m]]:
            if span[0] < pos < span[1]:
                lo = util.binsearch(mut_pos, span[0])[0]
                if lo is None:
                    lo = -0.5
                hi = util.binsearch(mut_pos, span[1])[1]
                if hi is None:
                    hi = len(mut_pos) - 0.5
                result.append((lo, hi))
                break
        else:
            assert False, m

    return result
Esempio n. 9
0
def sample(weights):
    """
    Randomly choose an index between 0 and len(weights) - 1.

    Index i is chosen with probability weights[i] / sum(weights).
    Zero-weight indices are never returned.

    Raises ValueError if `weights` is empty or does not have a positive sum.
    """
    import bisect

    # Materialise once so generators / array-likes are handled uniformly.
    weights = list(weights)
    if not weights:
        raise ValueError("sample() requires at least one weight")
    total = float(sum(weights))
    if not total > 0:
        raise ValueError("sample() requires weights with a positive sum")

    # cdf[i] is the total weight of all items before i, so item i owns the
    # half-open interval [cdf[i], cdf[i] + weights[i]).
    cdf = [0.0]
    for w in weights[:-1]:
        cdf.append(cdf[-1] + w)

    # random.random() is in [0, 1), so pick lies in [0, total).
    pick = random.random() * total

    # bisect_right - 1 is the LAST i with cdf[i] <= pick.  A zero-weight item
    # shares its cdf value with its successor, so it can never be the last
    # such i (exact ties resolve to the successor).
    return bisect.bisect_right(cdf, pick) - 1
Esempio n. 10
0
def sample(weights):
    """
    Randomly choose an index between 0 and len(weights) - 1.

    Index i is chosen with probability weights[i] / sum(weights).
    Zero-weight indices are never returned.

    Raises ValueError if `weights` is empty or does not have a positive sum.
    """
    import bisect

    # Materialise once so generators / array-likes are handled uniformly.
    weights = list(weights)
    if not weights:
        raise ValueError("sample() requires at least one weight")
    total = float(sum(weights))
    if not total > 0:
        raise ValueError("sample() requires weights with a positive sum")

    # cdf[i] is the total weight of all items before i, so item i owns the
    # half-open interval [cdf[i], cdf[i] + weights[i]).
    cdf = [0.0]
    for w in weights[:-1]:
        cdf.append(cdf[-1] + w)

    # random.random() is in [0, 1), so pick lies in [0, total).
    pick = random.random() * total

    # bisect_right - 1 is the LAST i with cdf[i] <= pick.  A zero-weight item
    # shares its cdf value with its successor, so it can never be the last
    # such i (exact ties resolve to the successor).
    return bisect.bisect_right(cdf, pick) - 1
Esempio n. 11
0
    def add_arg(self, arg):
        """
        Accumulate the coalescence and lineage statistics of one ARG.

        Appends the lineage counts of the ARG's initial marginal tree (from
        ``count_tree_lineages``) to ``self.init_trees``.  Then, for each SPR
        yielded by ``arglib.iter_arg_sprs(arg, use_local=True)``:

        * increments ``self.ncoals`` at the time index of the re-coalescence
          time;
        * computes, for every interval between consecutive entries of
          ``self.times``, the time-weighted average number of lineages in the
          local tree just left of the recombination position, with the
          broken branch removed;
        * adds those averages to ``self.k_lineages`` for the intervals from
          the recombination time index up to (not including) the
          re-coalescence time index.

        ``self.times`` is binary-searched, so it is presumably sorted
        ascending (TODO confirm).
        """
        nleaves = len(list(arg.leaves()))
        times = self.times
        assert times
        eps = 1e-3

        def get_local_children(node, pos, local):
            return set(child for child in arg.get_local_children(node, pos)
                       if child in local)

        def get_parent(node, pos, local):
            # climb past ancestors that have only one local child
            parent = arg.get_local_parent(node, pos)
            while len(get_local_children(parent, pos, local)) == 1:
                parent = arg.get_local_parent(parent, pos)
            return parent

        # add initial tree
        tree = arg.get_marginal_tree(arg.start)
        starts, ends, time_steps = count_tree_lineages(tree, times)
        self.init_trees.append({
            "starts": starts,
            "ends": ends,
            "time_steps": time_steps
        })

        # loop through sprs
        for recomb_pos, (rnode, rtime), (cnode, ctime), local in \
                arglib.iter_arg_sprs(arg, use_local=True):
            i, _ = util.binsearch(times, ctime)
            self.ncoals[i] += 1

            recomb_node = arg[rnode]
            # inspect the local tree just left of the recombination point
            broken_node = get_parent(recomb_node, recomb_pos - eps, local)
            coals = [0.0] + [
                node.age for node in local
                if len(get_local_children(node, recomb_pos - eps, local)) == 2
            ]

            coals.sort()
            # list() so counts can be decremented in place (a Python 3 range
            # object is immutable).
            nlineages = list(range(nleaves, 0, -1))
            assert len(nlineages) == len(coals)

            # subtract broken branch
            r = coals.index(recomb_node.age)
            r2 = coals.index(broken_node.age)
            for i in range(r, r2):
                nlineages[i] -= 1

            # get average number of branches in the time interval
            # list() because time-step markers are appended below (zip is
            # lazy in Python 3).
            data = list(zip(coals, nlineages))
            for t in times[1:]:
                data.append((t, "time step"))
            # Sort by time.  At equal times, lineage-count entries (ints) come
            # before "time step" markers, which is the order Python 2 gives
            # for int < str.  Python 3 raises TypeError when comparing int
            # with str, which happens whenever a coalescence age equals a
            # time point, so the ordering is spelled out explicitly.
            data.sort(key=lambda item: (item[0],
                                        isinstance(item[1], str),
                                        item[1]))

            lineages_per_time = []
            counts = []
            last_lineages = 0
            last_time = 0.0
            for a, b in data:
                if b != "time step":
                    if a > last_time:
                        counts.append((last_lineages, a - last_time))
                    last_lineages = b
                else:
                    counts.append((last_lineages, a - last_time))
                    s = sum(u * v for u, v in counts)
                    total_time = sum(v for u, v in counts)
                    # Durations are non-negative (data is time-sorted), so a
                    # zero total_time implies s == 0: this branch also avoids
                    # dividing by zero.
                    if s == 0.0:
                        lineages_per_time.append(last_lineages)
                    else:
                        lineages_per_time.append(s / total_time)
                    counts = []
                last_time = a

            assert len(lineages_per_time) == len(self.time_steps)

            r, _ = util.binsearch(times, rtime)
            c, _ = util.binsearch(times, ctime)
            for j in range(r, c):
                self.k_lineages[j] += lineages_per_time[j]
def find_region_pos(regions, pos):
    """Return the index of the first region starting at or after `pos`.

    `regions` must be sorted by `.start`.  The result is the upper index
    util.binsearch reports when searching region starts for ``pos - 1``.
    With integer starts this is presumably the first region whose start is
    >= pos.  TODO: confirm util.binsearch's behaviour on an exact match
    with ``pos - 1``.  Presumably None when no such region exists.
    """
    # Explicit three-way comparison, equivalent to Python 2's `cmp`, which
    # no longer exists in Python 3.
    _, top = util.binsearch(regions, pos - 1,
                            lambda a, b: (a.start > b) - (a.start < b))
    return top
Esempio n. 13
0
def find_region_pos(regions, pos):
    """Return the index of the first region starting at or after `pos`.

    `regions` must be sorted by `.start`.  The result is the upper index
    util.binsearch reports when searching region starts for ``pos - 1``.
    With integer starts this is presumably the first region whose start is
    >= pos.  TODO: confirm util.binsearch's behaviour on an exact match
    with ``pos - 1``.  Presumably None when no such region exists.
    """
    # Explicit three-way comparison, equivalent to Python 2's `cmp`, which
    # no longer exists in Python 3.
    _, top = util.binsearch(regions, pos - 1,
                            lambda a, b: (a.start > b) - (a.start < b))
    return top
Esempio n. 14
0
    def add_arg(self, arg):
        """
        Accumulate the coalescence and lineage statistics of one ARG.

        Appends the lineage counts of the ARG's initial marginal tree (from
        ``count_tree_lineages``) to ``self.init_trees``.  Then, for each SPR
        yielded by ``arglib.iter_arg_sprs(arg, use_local=True)``:

        * increments ``self.ncoals`` at the time index of the re-coalescence
          time;
        * computes, for every interval between consecutive entries of
          ``self.times``, the time-weighted average number of lineages in the
          local tree just left of the recombination position, with the
          broken branch removed;
        * adds those averages to ``self.k_lineages`` for the intervals from
          the recombination time index up to (not including) the
          re-coalescence time index.

        ``self.times`` is binary-searched, so it is presumably sorted
        ascending (TODO confirm).
        """
        nleaves = len(list(arg.leaves()))
        times = self.times
        assert times
        eps = 1e-3

        def get_local_children(node, pos, local):
            return set(child for child in arg.get_local_children(node, pos)
                       if child in local)

        def get_parent(node, pos, local):
            # climb past ancestors that have only one local child
            parent = arg.get_local_parent(node, pos)
            while len(get_local_children(parent, pos, local)) == 1:
                parent = arg.get_local_parent(parent, pos)
            return parent

        # add initial tree
        tree = arg.get_marginal_tree(arg.start)
        starts, ends, time_steps = count_tree_lineages(tree, times)
        self.init_trees.append({
            "starts": starts,
            "ends": ends,
            "time_steps": time_steps
        })

        # loop through sprs
        for recomb_pos, (rnode, rtime), (cnode, ctime), local in \
                arglib.iter_arg_sprs(arg, use_local=True):
            i, _ = util.binsearch(times, ctime)
            self.ncoals[i] += 1

            recomb_node = arg[rnode]
            # inspect the local tree just left of the recombination point
            broken_node = get_parent(recomb_node, recomb_pos - eps, local)
            coals = [0.0] + [
                node.age for node in local
                if len(get_local_children(node, recomb_pos - eps, local)) == 2
            ]

            coals.sort()
            # list() so counts can be decremented in place
            nlineages = list(range(nleaves, 0, -1))
            assert len(nlineages) == len(coals)

            # subtract broken branch
            r = coals.index(recomb_node.age)
            r2 = coals.index(broken_node.age)
            for i in range(r, r2):
                nlineages[i] -= 1

            # get average number of branches in the time interval
            data = list(zip(coals, nlineages))
            for t in times[1:]:
                data.append((t, "time step"))
            # Sort by time.  At equal times, lineage-count entries (ints) come
            # before "time step" markers, which is the order Python 2 gives
            # for int < str.  Python 3 raises TypeError when comparing int
            # with str, which happens whenever a coalescence age equals a
            # time point, so the ordering is spelled out explicitly.
            data.sort(key=lambda item: (item[0],
                                        isinstance(item[1], str),
                                        item[1]))

            lineages_per_time = []
            counts = []
            last_lineages = 0
            last_time = 0.0
            for a, b in data:
                if b != "time step":
                    if a > last_time:
                        counts.append((last_lineages, a - last_time))
                    last_lineages = b
                else:
                    counts.append((last_lineages, a - last_time))
                    s = sum(u * v for u, v in counts)
                    total_time = sum(v for u, v in counts)
                    # Durations are non-negative (data is time-sorted), so a
                    # zero total_time implies s == 0: this branch also avoids
                    # dividing by zero.
                    if s == 0.0:
                        lineages_per_time.append(last_lineages)
                    else:
                        lineages_per_time.append(s / total_time)
                    counts = []
                last_time = a

            assert len(lineages_per_time) == len(self.time_steps)

            r, _ = util.binsearch(times, rtime)
            c, _ = util.binsearch(times, ctime)
            for j in range(r, c):
                self.k_lineages[j] += lineages_per_time[j]