예제 #1
0
 def __init__(self, ts, sample, ancestors):
     """
     Initialise the ancestry-tracking state for ``ts``.

     :param ts: The tree sequence being processed; ``num_nodes`` and
         ``sequence_length`` are read from it.
     :param sample: Iterable of sample node IDs. Each ID, in the order
         given, is passed to ``self.add_ancestry`` with the interval
         ``(0, ts.sequence_length)``.
     :param ancestors: Iterable of ancestor node IDs of interest.
     """
     self.ts = ts
     # Materialise ``sample`` once: a one-shot iterator would otherwise be
     # exhausted by ``set()`` below, leaving no sample seeded with ancestry.
     sample = list(sample)
     self.samples = set(sample)
     # Build the set of valid node IDs once and reuse it for both checks.
     valid_nodes = set(range(0, ts.num_nodes))
     assert self.samples.issubset(valid_nodes)
     self.ancestors = set(ancestors)
     assert self.ancestors.issubset(valid_nodes)
     self.table = tskit.EdgeTable()
     self.sequence_length = ts.sequence_length
     # One slot per node, initially None; presumably the head and tail of
     # each node's ancestry-segment list maintained by add_ancestry — confirm.
     self.A_head = [None for _ in range(ts.num_nodes)]
     self.A_tail = [None for _ in range(ts.num_nodes)]
     for sample_id in sample:
         self.add_ancestry(0, self.sequence_length, sample_id, sample_id)
     self.edge_buffer = {}
예제 #2
0
def get_ancestry_table(ts, populations, samples=None, keep_ancestors=False):
    """
    Return an AncestryTable of local ancestry for the given samples.

    ``ts.tables.map_ancestors`` is run between ``samples`` and every node
    whose population is in ``populations``. Each resulting segment whose
    child is one of ``samples`` becomes a row ``(left, right, population,
    child)``. ``population`` is set as follows:

    * the sample's own population, if that is one of ``populations``;
    * otherwise, the population of the ancestor it maps to.

    Adjacent segments with the same population and child are merged
    (``EdgeTable.squash``). The rows are sorted by child, then left, right
    and population.

    :param ts: The tree sequence containing the dataset.
    :type ts: tskit.TreeSequence
    :param populations: Ancestral population IDs of interest (non-empty).
    :type populations: list of int
    :param samples: Sample node IDs of interest. If None, all samples in
        ``ts``.
    :type samples: list of int
    :param bool keep_ancestors: Accepted for API compatibility; it is not
        currently used by this function.
    :return: The ancestry table listing the local ancestry of the genomic
        segments corresponding to the child nodes.
    :rtype: :class:`slime.AncestryTable`
    :raises ValueError: If no node in ``ts`` has a population in
        ``populations``.
    """
    # Extract ancestors with the given population labels.
    # TODO later: make this work with more flexible inputs.
    assert len(populations) > 0
    if samples is None:
        samples = ts.samples()
    # A set gives O(1) membership tests in the loops below.
    population_set = set(populations)
    ancestors = [u.id for u in ts.nodes() if u.population in population_set]
    if len(ancestors) == 0:
        raise ValueError("There are no nodes with the given population IDs.")

    # Fetch the tables once; the node population column is read a single
    # time instead of once per output row.
    tables = ts.tables
    ancestor_table = tables.map_ancestors(samples=samples,
                                          ancestors=ancestors)
    node_population = tables.nodes.population
    sample_set = set(samples)

    # Copy relevant edges into a new EdgeTable, relabelling the parent
    # column with a population ID.
    edges_to_squash = tskit.EdgeTable()
    for row in ancestor_table:
        if row.child in sample_set:
            sample_pop = node_population[row.child]
            if sample_pop in population_set:
                pop = sample_pop
            else:
                pop = node_population[row.parent]
            edges_to_squash.add_row(left=row.left,
                                    right=row.right,
                                    parent=pop,
                                    child=row.child)

    # Merge adjacent segments sharing (population, child).
    edges_to_squash.squash()

    # Sort rows by (child, left, right, population). np.lexsort treats its
    # *last* key as the primary one, hence the reversed key order.
    left = edges_to_squash.left
    right = edges_to_squash.right
    population = edges_to_squash.parent
    child = edges_to_squash.child
    order = np.lexsort((population, right, left, child))

    # Change into an ancestry table.
    ret = AncestryTable()
    ret.set_columns(left=left[order],
                    right=right[order],
                    population=population[order],
                    child=child[order])
    return ret
def reedgeucation(pstate, ischild, minparent, maxparent):
    """
    Rebuild ``pstate.tables.edges`` by splicing buffered edges into it.

    Births are visited by walking ``pstate.generation_offsets`` in reverse.
    Each offset ``o`` is unpacked into ``range(*o)`` to give birth indexes
    ``i``. For birth ``i``:

    * ``pstate.pnodes[i]`` holds its two parent node IDs.
    * ``pstate.buffered_edges[i][n]`` holds the buffered edges of parent
      ``n``. These are tuples passed positionally to ``EdgeTable.add_row``,
      i.e. ``(left, right, parent, child)``.

    Births with ``i >= len(pstate.parents)`` only contribute their buffered
    edges, which are collected in ``edges_new_births``.

    For every other birth, the existing rows of ``pstate.tables.edges`` are
    copied, in order, into ``edges_previous_births``. The birth's buffered
    edges are interleaved at positions chosen from
    ``minparent``/``maxparent``.

    After the loop, any existing rows not yet copied are appended. The
    edge table is then overwritten with the new-birth edges followed by
    ``edges_previous_births``.

    :param pstate: Simulation state; ``generation_offsets``, ``pnodes``,
        ``parents``, ``buffered_edges`` and ``tables`` are read.
    :param ischild: Per-node values, indexed by node ID, tested for
        truthiness.
    :param minparent: Per-node value, or ``tskit.NULL``. Presumably the row
        index of the first existing edge whose parent is that node (it is
        used as a slice start into ``pstate.tables.edges``) — confirm.
    :param maxparent: Per-node value, or ``tskit.NULL``. Presumably the row
        index of the last such edge (used as an inclusive slice end).
    :return: None; ``pstate.tables.edges`` is modified in place.
    """

    E = 0  # next row of the existing edge table that has not been copied yet
    edges_new_births = tskit.EdgeTable()
    edges_previous_births = tskit.EdgeTable()
    for o in reversed(pstate.generation_offsets):
        for i in range(*o):
            # Get parent node IDs
            pnodes = pstate.pnodes[i]
            if i < len(pstate.parents):
                # Does each parent already own rows in the existing table?
                if minparent[pnodes[0]] == tskit.NULL:
                    isparent0 = False
                else:
                    isparent0 = True
                if minparent[pnodes[1]] == tskit.NULL:
                    isparent1 = False
                else:
                    isparent1 = True
                # mn0 = minparent[pnodes[0]]
                # (not needed: parent 0's rows are copied from E up to mx0)
                mx0 = maxparent[pnodes[0]]
                mn1 = minparent[pnodes[1]]
                mx1 = maxparent[pnodes[1]]
                if isparent0 is True and isparent1 is True:
                    # Both parents have existing rows. Copy rows through
                    # parent 0's last row, then parent 0's buffered edges.
                    # Then copy rows through parent 1's last row, then
                    # parent 1's buffered edges.
                    # NOTE(review): this assumes mx0 < mx1. Otherwise the
                    # second slice is empty and E moves backwards, so rows
                    # mx1+1..mx0 would be copied again — confirm ordering.
                    assert mx0 != tskit.NULL
                    assert mx1 != tskit.NULL
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mx0 + 1],
                        pstate.tables.edges.right[E:mx0 + 1],
                        pstate.tables.edges.parent[E:mx0 + 1],
                        pstate.tables.edges.child[E:mx0 + 1],
                    )
                    E = mx0 + 1
                    for k in pstate.buffered_edges[i][0]:
                        assert k[2] == pnodes[0]
                        edges_previous_births.add_row(*k)
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mx1 + 1],
                        pstate.tables.edges.right[E:mx1 + 1],
                        pstate.tables.edges.parent[E:mx1 + 1],
                        pstate.tables.edges.child[E:mx1 + 1],
                    )
                    E = mx1 + 1
                    for k in pstate.buffered_edges[i][1]:
                        assert k[2] == pnodes[1]
                        edges_previous_births.add_row(*k)
                elif isparent0 is True:
                    # Only parent 0 has existing rows: copy through its last
                    # row, then both parents' buffered edges.
                    assert mx0 != tskit.NULL
                    assert isparent1 is False
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mx0 + 1],
                        pstate.tables.edges.right[E:mx0 + 1],
                        pstate.tables.edges.parent[E:mx0 + 1],
                        pstate.tables.edges.child[E:mx0 + 1],
                    )
                    E = mx0 + 1
                    for k in pstate.buffered_edges[i][0]:
                        edges_previous_births.add_row(*k)
                    for k in pstate.buffered_edges[i][1]:
                        edges_previous_births.add_row(*k)
                elif isparent1 is True:
                    # Only parent 1 has existing rows. Copy the rows before
                    # its first row, then parent 0's buffered edges. Then
                    # copy parent 1's existing rows, then its buffered edges.
                    assert mn1 != tskit.NULL
                    assert mx1 != tskit.NULL
                    assert isparent0 is False
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mn1],
                        pstate.tables.edges.right[E:mn1],
                        pstate.tables.edges.parent[E:mn1],
                        pstate.tables.edges.child[E:mn1],
                    )
                    for k in pstate.buffered_edges[i][0]:
                        edges_previous_births.add_row(*k)
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[mn1:mx1 + 1],
                        pstate.tables.edges.right[mn1:mx1 + 1],
                        pstate.tables.edges.parent[mn1:mx1 + 1],
                        pstate.tables.edges.child[mn1:mx1 + 1],
                    )
                    for k in pstate.buffered_edges[i][1]:
                        edges_previous_births.add_row(*k)
                    E = mx1 + 1
                else:
                    # Neither parent has existing rows. If either parent is
                    # flagged in ischild, first copy the pending existing
                    # rows whose parent's time is strictly less than the
                    # time of pnodes[0]. Then append both parents' buffered
                    # edges.
                    ptime = pstate.tables.nodes.time[pnodes[0]]
                    if ischild[pnodes[0]] or ischild[pnodes[1]]:
                        while (E < len(pstate.tables.edges)
                               and pstate.tables.nodes.time[
                                   pstate.tables.edges.parent[E]] < ptime):
                            e = pstate.tables.edges[E]
                            edges_previous_births.add_row(
                                e.left, e.right, e.parent, e.child)
                            E += 1

                    for n in [0, 1]:
                        for k in pstate.buffered_edges[i][n]:
                            assert k[2] == pnodes[n], f"{k} {pnodes}"
                            edges_previous_births.add_row(*k)
            else:
                # Birth index beyond pstate.parents: only its buffered edges
                # are recorded, in the separate new-births table.
                for n in [0, 1]:
                    for k in pstate.buffered_edges[i][n]:
                        assert k[2] == pnodes[n]
                        edges_new_births.add_row(*k)

    # Copy any existing rows that no birth above consumed.
    while E < len(pstate.tables.edges):
        edges_previous_births.add_row(
            pstate.tables.edges[E].left,
            pstate.tables.edges[E].right,
            pstate.tables.edges[E].parent,
            pstate.tables.edges[E].child,
        )
        E += 1
    # Overwrite the edge table: new-birth edges first, then the merged
    # previous-birth edges.
    pstate.tables.edges.set_columns(
        edges_new_births.left,
        edges_new_births.right,
        edges_new_births.parent,
        edges_new_births.child,
    )
    pstate.tables.edges.append_columns(
        edges_previous_births.left,
        edges_previous_births.right,
        edges_previous_births.parent,
        edges_previous_births.child,
    )
def stitch_tables(
    tables: tskit.TableCollection,
    buffered_edges: typing.List[BufferedEdgeList],
    alive_at_last_simplification: np.array,
):
    """
    Rebuild ``tables.edges`` from buffered edges and return ``tables``.

    If ``tables.edges`` is empty (the first simplification), every buffered
    edge is appended directly. ``buffered_edges`` is walked in reverse,
    emitting one row per descendant of each entry.

    Otherwise a new edge table is assembled and then replaces
    ``tables.edges``:

    1. It starts with the buffered edges whose parent's node time is
       strictly less than the minimum node time in
       ``alive_at_last_simplification``.
    2. It is then passed to ``handle_alive_nodes_from_last_time`` (defined
       elsewhere), which presumably adds the edges for nodes alive at the
       last simplification.

    The result is then validated with two checks:

    * Parent times must be non-decreasing down the table. On failure, the
      table prefix up to the offending row is written to ``dump.txt``
      before the assertion fires.
    * The child IDs within each run of rows sharing a parent must be
      sorted.

    :param tables: Table collection whose edge table is rebuilt in place.
    :param buffered_edges: Entries exposing ``parent`` and ``descendants``;
        each descendant exposes ``left``, ``right`` and ``child``.
    :param alive_at_last_simplification: Node IDs used to compute the
        time cut-off described above.
    :return: ``tables`` (the same object, modified in place).
    """
    if len(tables.edges) == 0:
        # This is our first simplification: there are no existing edges to
        # merge with, so append every buffered edge directly.
        for b in reversed(buffered_edges):
            for d in b.descendants:
                tables.edges.add_row(left=d.left,
                                     right=d.right,
                                     parent=b.parent,
                                     child=d.child)
        return tables

    # Read the node-time column once. A tskit table column read returns a
    # copy of the whole column, so doing it per buffered entry is wasteful.
    node_time = tables.nodes.time

    # Get the time of the most recent node from alive_at_last_simplification
    # FIXME: this is better done by recording the last time of simplification,
    # passing that to here, and adding all elements whose parent times are
    # more recent
    # NOTE(review): with no alive nodes the cut-off stays at -1, so (for
    # non-negative node times) no buffered edge is copied below — confirm.
    time = -1
    if len(alive_at_last_simplification) > 0:
        time = node_time[alive_at_last_simplification].min()
    stitched_edges = tskit.EdgeTable()
    for b in reversed(buffered_edges):
        if node_time[b.parent] < time:
            for d in b.descendants:
                stitched_edges.add_row(left=d.left,
                                       right=d.right,
                                       parent=b.parent,
                                       child=d.child)

    # The return value of this helper is not needed here.
    handle_alive_nodes_from_last_time(tables, stitched_edges,
                                      alive_at_last_simplification,
                                      buffered_edges)

    tables.edges.set_columns(
        left=stitched_edges.left,
        right=stitched_edges.right,
        parent=stitched_edges.parent,
        child=stitched_edges.child,
    )

    # Validation. Columns are read once, after the edge table is final.
    edge_parent = tables.edges.parent
    edge_child = tables.edges.child
    parent_time = tables.nodes.time[edge_parent]
    num_edges = len(edge_parent)

    # Parent times must be non-decreasing down the edge table.
    for i in range(1, num_edges):
        ti = parent_time[i]
        tim1 = parent_time[i - 1]
        if not tim1 <= ti:
            # Dump the table prefix up to the offending row for debugging.
            with open("dump.txt", "w") as f:
                for j in range(i + 1):
                    f.write(f"{tables.edges[j]} {parent_time[j]}\n")

        assert (
            tim1 <= ti
        ), f"{tim1} {ti} {edge_parent[i - 1]} {edge_parent[i]} {tables.edges[i - 1]} {tables.edges[i]}"

    # NOTE: a per-child "rows are adjacent" check used to sit here. It only
    # updated its tracker inside `if tracker != -1`, so it never ran. It was
    # dropped rather than enabled: a child with edges to two different
    # parents would not have adjacent rows. TODO confirm the intended
    # invariant.

    # Within each run of rows sharing a parent, child IDs must be sorted.
    E = 0
    while E < num_edges:
        p = edge_parent[E]
        children = []
        while E < num_edges and edge_parent[E] == p:
            children.append(edge_child[E])
            E += 1
        assert children == sorted(
            children
        ), f"{children} {p in alive_at_last_simplification} {buffered_edges[p]}"

    return tables
#         prev2 = np.where(pstate.tables.edges.parent == p2)[0]
#         if len(prev2) > 0:
#             print(p2, pstate.tables.edges.parent[prev2])

# Re-express node times relative to the maximum: new_time = -(time - max).
# The node(s) with the largest time map to 0; all others map to positive
# values.
# NOTE(review): NodeTable.set_columns is given only flags and time here.
# Confirm that resetting the remaining node columns (e.g. population) to
# their defaults is intended.
pstate.tables.nodes.set_columns(
    flags=flags,
    time=-1.0 * (pstate.tables.nodes.time - pstate.tables.nodes.time.max()))

# Count buffered edges across both parents of every buffered birth.
# NOTE(review): `e[0] + e[1]` concatenates only if these are lists/tuples;
# numpy arrays would be added element-wise instead — confirm their type.
new_edges = 0
for e in pstate.buffered_edges:
    for j in e[0] + e[1]:
        new_edges += 1

# Sanity check: existing plus buffered edges must equal
# tcopy_num_edges_b4_simplify. It is defined outside this excerpt;
# presumably the edge count taken before simplifying — confirm.
assert len(pstate.tables.edges) + new_edges == tcopy_num_edges_b4_simplify

# Working tables and counters consumed by the loop that follows.
temp_edges = tskit.EdgeTable()
temp_edges_from_before = tskit.EdgeTable()

new_edges2 = 0
edges_added = 0
E = 0
# NOTE: issue now is corner case of pre-existing edge
# not in edge table, but then leaving descendants later
for o in reversed(pstate.generation_offsets):
    print("range =", *o)
    for i in range(*o):
        # Fetch the parent node IDs
        pnodes = pstate.pnodes[i]
        if i < len(pwhere):
            pnodes = pstate.pnodes[i]
            where0 = pwhere[i][1]