def __init__(self, ts, sample, ancestors):
    """
    Initialise the per-node ancestry state for the tree sequence ``ts``.

    :param ts: Tree sequence being processed; ``num_nodes`` and
        ``sequence_length`` are read from it.
    :param sample: Iterable of sample node IDs. Each ID must lie in
        ``range(ts.num_nodes)``. It is iterated twice (once to build the
        set, once to seed the ancestry), so it must be re-iterable.
    :param ancestors: Iterable of ancestor node IDs, subject to the same
        range check as ``sample``.
    """
    self.ts = ts

    # Both ID collections must only name nodes that exist in ``ts``.
    self.samples = set(sample)
    assert self.samples <= set(range(ts.num_nodes))
    self.ancestors = set(ancestors)
    assert self.ancestors <= set(range(ts.num_nodes))

    # Output edge table and the genome length used when seeding ancestry.
    self.table = tskit.EdgeTable()
    self.sequence_length = ts.sequence_length

    # One head/tail slot per node, all initially empty.
    node_count = ts.num_nodes
    self.A_head = [None] * node_count
    self.A_tail = [None] * node_count

    # Seed every sample with a segment spanning the whole sequence, passing
    # the sample's own ID for both trailing arguments of add_ancestry.
    for node_id in sample:
        self.add_ancestry(0, self.sequence_length, node_id, node_id)

    # NOTE: assigned only after the seeding loop above, as in the original.
    self.edge_buffer = {}
def get_ancestry_table(ts, populations, samples=None, keep_ancestors=False):
    """
    Returns an AncestryTable showing local ancestry information for the
    specified set of samples.

    :param tskit.TreeSequence ts: The tree sequence containing the dataset.
    :param list populations: A non-empty list of ancestral population IDs
        (ints) of interest.
    :param list samples: A list of sample node IDs of interest. If None,
        all samples in the inputted tree sequence are used.
    :param bool keep_ancestors: Accepted for interface compatibility; the
        current implementation does not read it.
    :return: The ancestry table listing the local ancestry of the genomic
        segments corresponding to the child nodes, sorted by
        (child, left, right, population).
    :rtype: :class:`slime.AncestryTable`
    :raises ValueError: If no node in ``ts`` belongs to any of the given
        populations.
    """
    # Extract ancestors with the given population labels.
    # TODO later: make this work with more flexible inputs.
    assert len(populations) > 0
    if samples is None:
        samples = ts.samples()

    # Sets give O(1) membership tests inside the per-edge loop below.
    population_ids = set(populations)
    sample_ids = set(samples)

    ancestors = [u.id for u in ts.nodes() if u.population in population_ids]
    if len(ancestors) == 0:
        raise ValueError("There are no nodes with the given population IDs.")

    # Fetch the tables and the node population column once: in tskit,
    # ``ts.tables`` and table column attributes may return fresh copies on
    # every access, which is very costly when repeated per edge.
    tables = ts.tables
    node_population = tables.nodes.population

    # Apply map_ancestors.
    ancestor_table = tables.map_ancestors(samples=samples, ancestors=ancestors)

    # Copy relevant edges into a new EdgeTable, storing a population ID in
    # the ``parent`` column so that squash() merges adjacent segments that
    # share the same (population, child) pair.
    edges_to_squash = tskit.EdgeTable()
    for row in ancestor_table:
        if row.child not in sample_ids:
            continue
        sample_pop = node_population[row.child]
        if sample_pop in population_ids:
            # The sample itself already carries one of the requested labels.
            pop = sample_pop
        else:
            pop = node_population[row.parent]
        edges_to_squash.add_row(
            left=row.left, right=row.right, parent=pop, child=row.child
        )

    edges_to_squash.squash()

    # Sort rows by (child, left, right, population). np.lexsort treats the
    # LAST key as the primary one, hence the reversed order of the keys.
    left = edges_to_squash.left
    right = edges_to_squash.right
    population = edges_to_squash.parent
    child = edges_to_squash.child
    order = np.lexsort((population, right, left, child))

    # Change into an ancestry table.
    ret = AncestryTable()
    ret.set_columns(
        left=left[order],
        right=right[order],
        population=population[order],
        child=child[order],
    )
    return ret
def reedgeucation(pstate, ischild, minparent, maxparent):
    """
    Rebuild ``pstate.tables.edges`` by interleaving the edges already in the
    table with the edges held in ``pstate.buffered_edges``.

    The entries of ``pstate.generation_offsets`` are visited in reverse, and
    each one is unpacked into ``range(*o)`` to obtain indexes ``i``. For
    ``i < len(pstate.parents)`` the buffered edges are merged, together with
    slices of the existing edge table, into a "previous births" table; for
    all other ``i`` they are collected in a "new births" table. Finally the
    edge table is overwritten with the new-birth edges followed by the
    previous-birth edges.

    :param pstate: Object providing ``generation_offsets``, ``pnodes``,
        ``parents``, ``buffered_edges`` and ``tables``.
    :param ischild: Indexable by node ID; a truthy entry triggers the
        time-based advance of the edge cursor in the final branch.
        Presumably flags nodes that occur as a child in the existing edge
        table -- confirm against the caller.
    :param minparent: Indexable by node ID. ``tskit.NULL`` is taken to mean
        that the node is not a parent in the existing edge table; otherwise
        the value is used as an index into that table.
    :param maxparent: Indexable by node ID; used as the inclusive last index
        of a slice of the existing edge table.
    :return: None. ``pstate.tables.edges`` is modified in place.

    NOTE(review): the nesting below was reconstructed from a copy of the
    source in which indentation had been lost; confirm against the original
    file (in particular the final ``else`` branches and the position of
    ``E = mx1 + 1``).
    """
    # Cursor into the existing edge table: rows before E have been copied.
    E = 0
    edges_new_births = tskit.EdgeTable()
    edges_previous_births = tskit.EdgeTable()
    for o in reversed(pstate.generation_offsets):
        for i in range(*o):
            # Get parent node IDs
            pnodes = pstate.pnodes[i]
            if i < len(pstate.parents):
                if minparent[pnodes[0]] == tskit.NULL:
                    isparent0 = False
                else:
                    isparent0 = True
                if minparent[pnodes[1]] == tskit.NULL:
                    isparent1 = False
                else:
                    isparent1 = True
                # mn0 = minparent[pnodes[0]]
                mx0 = maxparent[pnodes[0]]
                mn1 = minparent[pnodes[1]]
                mx1 = maxparent[pnodes[1]]
                if isparent0 is True and isparent1 is True:
                    # Both nodes already have edges in the table: copy the
                    # existing rows up to and including mx0, add node 0's
                    # buffered edges, then do the same up to mx1 for node 1.
                    assert mx0 != tskit.NULL
                    assert mx1 != tskit.NULL
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mx0 + 1],
                        pstate.tables.edges.right[E:mx0 + 1],
                        pstate.tables.edges.parent[E:mx0 + 1],
                        pstate.tables.edges.child[E:mx0 + 1],
                    )
                    E = mx0 + 1
                    for k in pstate.buffered_edges[i][0]:
                        assert k[2] == pnodes[0]
                        edges_previous_births.add_row(*k)
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mx1 + 1],
                        pstate.tables.edges.right[E:mx1 + 1],
                        pstate.tables.edges.parent[E:mx1 + 1],
                        pstate.tables.edges.child[E:mx1 + 1],
                    )
                    E = mx1 + 1
                    for k in pstate.buffered_edges[i][1]:
                        assert k[2] == pnodes[1]
                        edges_previous_births.add_row(*k)
                elif isparent0 is True:
                    # Only node 0 has existing edges: copy up to mx0, then
                    # add the buffered edges of both nodes.
                    assert mx0 != tskit.NULL
                    assert isparent1 is False
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mx0 + 1],
                        pstate.tables.edges.right[E:mx0 + 1],
                        pstate.tables.edges.parent[E:mx0 + 1],
                        pstate.tables.edges.child[E:mx0 + 1],
                    )
                    E = mx0 + 1
                    for k in pstate.buffered_edges[i][0]:
                        edges_previous_births.add_row(*k)
                    for k in pstate.buffered_edges[i][1]:
                        edges_previous_births.add_row(*k)
                elif isparent1 is True:
                    # Only node 1 has existing edges: copy the rows before
                    # mn1, add node 0's buffered edges, then copy node 1's
                    # rows mn1..mx1 followed by its buffered edges.
                    assert mn1 != tskit.NULL
                    assert mx1 != tskit.NULL
                    assert isparent0 is False
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[E:mn1],
                        pstate.tables.edges.right[E:mn1],
                        pstate.tables.edges.parent[E:mn1],
                        pstate.tables.edges.child[E:mn1],
                    )
                    for k in pstate.buffered_edges[i][0]:
                        edges_previous_births.add_row(*k)
                    edges_previous_births.append_columns(
                        pstate.tables.edges.left[mn1:mx1 + 1],
                        pstate.tables.edges.right[mn1:mx1 + 1],
                        pstate.tables.edges.parent[mn1:mx1 + 1],
                        pstate.tables.edges.child[mn1:mx1 + 1],
                    )
                    for k in pstate.buffered_edges[i][1]:
                        edges_previous_births.add_row(*k)
                    # NOTE(review): placed after the loop so the cursor
                    # advances even when there are no buffered edges --
                    # confirm against the original indentation.
                    E = mx1 + 1
                else:
                    # Neither node has existing edges as a parent.
                    ptime = pstate.tables.nodes.time[pnodes[0]]
                    if ischild[pnodes[0]] or ischild[pnodes[1]]:
                        # Copy existing rows whose parent time is strictly
                        # less than this parent's time before inserting.
                        while (E < len(pstate.tables.edges)
                               and pstate.tables.nodes.time[
                                   pstate.tables.edges.parent[E]] < ptime):
                            e = pstate.tables.edges[E]
                            edges_previous_births.add_row(
                                e.left, e.right, e.parent, e.child)
                            E += 1
                    # NOTE(review): nesting of this loop relative to the
                    # ``if ischild`` test above is inferred -- confirm.
                    for n in [0, 1]:
                        for k in pstate.buffered_edges[i][n]:
                            assert k[2] == pnodes[n], f"{k} {pnodes}"
                            edges_previous_births.add_row(*k)
            else:
                # Index beyond pstate.parents: edges go to the new-births
                # table, which ends up first in the rebuilt edge table.
                for n in [0, 1]:
                    for k in pstate.buffered_edges[i][n]:
                        assert k[2] == pnodes[n]
                        edges_new_births.add_row(*k)
    # Copy whatever is left of the existing edge table.
    while E < len(pstate.tables.edges):
        edges_previous_births.add_row(
            pstate.tables.edges[E].left,
            pstate.tables.edges[E].right,
            pstate.tables.edges[E].parent,
            pstate.tables.edges[E].child,
        )
        E += 1
    # Replace the edge table: new births first, then previous births.
    pstate.tables.edges.set_columns(
        edges_new_births.left,
        edges_new_births.right,
        edges_new_births.parent,
        edges_new_births.child,
    )
    pstate.tables.edges.append_columns(
        edges_previous_births.left,
        edges_previous_births.right,
        edges_previous_births.parent,
        edges_previous_births.child,
    )
def stitch_tables(
    tables: tskit.TableCollection,
    buffered_edges: typing.List[BufferedEdgeList],
    alive_at_last_simplification: np.ndarray,
):
    """
    Merge the buffered edges into ``tables.edges`` and validate the result.

    If the edge table is empty (first simplification), every buffered edge
    is appended, walking ``buffered_edges`` in reverse, and ``tables`` is
    returned without further checks.

    Otherwise a new edge table is built: first the buffered edges of every
    parent whose node time is strictly less than the minimum node time among
    ``alive_at_last_simplification``, then whatever
    ``handle_alive_nodes_from_last_time`` adds. The result replaces
    ``tables.edges`` and three invariants are asserted:

    * parent times are non-decreasing along the edge table;
    * all edges sharing a child are contiguous;
    * within each run of equal parents, child IDs are sorted.

    If the first invariant fails, the offending prefix of the edge table is
    written to ``dump.txt`` in the current directory before asserting.

    :param tables: Table collection whose edge table is rewritten in place.
    :param buffered_edges: Sequence whose items expose ``parent`` and
        ``descendants`` (each descendant exposing ``left``, ``right`` and
        ``child``). It is also indexed by parent node ID when formatting an
        assertion message.
    :param alive_at_last_simplification: Node IDs used to index
        ``tables.nodes.time``; may be empty.
    :return: The same ``tables`` object that was passed in.
    :raises AssertionError: If any of the invariants above is violated.
    """
    if len(tables.edges) == 0:
        # this is our first simplification
        for b in reversed(buffered_edges):
            for d in b.descendants:
                tables.edges.add_row(
                    left=d.left, right=d.right, parent=b.parent, child=d.child
                )
        return tables

    # Column attributes are fetched once into locals: tskit hands back a
    # copy of the column on access, so re-reading them inside the loops
    # below would make those loops quadratic in the table size.
    node_time = tables.nodes.time

    # Get the time of the most recent node from alive_at_last_simplification
    # FIXME: this is better done by recording the last time of simplification,
    # passing that to here, and adding all elements whose parent times are
    # more recent
    time = -1
    if len(alive_at_last_simplification) > 0:
        time = node_time[alive_at_last_simplification].min()

    stitched_edges = tskit.EdgeTable()
    for b in reversed(buffered_edges):
        if node_time[b.parent] < time:
            for d in b.descendants:
                stitched_edges.add_row(
                    left=d.left, right=d.right, parent=b.parent, child=d.child
                )

    # Called for its effect on ``stitched_edges``; the returned counts are
    # not used here.
    handle_alive_nodes_from_last_time(
        tables, stitched_edges, alive_at_last_simplification, buffered_edges
    )

    tables.edges.set_columns(
        left=stitched_edges.left,
        right=stitched_edges.right,
        parent=stitched_edges.parent,
        child=stitched_edges.child,
    )

    # Do some validation. Columns are re-read here because the edge table
    # has just been replaced and the helper above was handed ``tables``.
    node_time = tables.nodes.time
    edge_parent = tables.edges.parent
    edge_child = tables.edges.child
    num_edges = len(edge_parent)

    # 1. Parent times must be non-decreasing along the edge table.
    for i in range(1, num_edges):
        tim1 = node_time[edge_parent[i - 1]]
        ti = node_time[edge_parent[i]]
        if not tim1 <= ti:
            # Dump the edges up to and including the offending row.
            with open("dump.txt", "w") as f:
                for j in range(i + 1):
                    row = tables.edges[j]
                    f.write(f"{row} {node_time[row.parent]}\n")
        assert (
            tim1 <= ti
        ), f"{tim1} {ti} {edge_parent[i - 1]} {edge_parent[i]} {tables.edges[i - 1]} {tables.edges[i]}"

    # 2. All edges for a given child must be adjacent.
    last_child = np.full(len(tables.nodes), -1, dtype=np.int32)
    for i, child in enumerate(edge_child):
        if last_child[child] != -1:
            assert last_child[child] == i - 1
        last_child[child] = i

    # 3. Within each run of equal parents, children must be sorted.
    E = 0
    while E < num_edges:
        p = edge_parent[E]
        children = []
        while E < num_edges and edge_parent[E] == p:
            children.append(edge_child[E])
            E += 1
        assert children == sorted(
            children
        ), f"{children} {p in alive_at_last_simplification} {buffered_edges[p]}"

    return tables
# prev2 = np.where(pstate.tables.edges.parent == p2)[0] # if len(prev2) > 0: # print(p2, pstate.tables.edges.parent[prev2]) pstate.tables.nodes.set_columns( flags=flags, time=-1.0 * (pstate.tables.nodes.time - pstate.tables.nodes.time.max())) new_edges = 0 for e in pstate.buffered_edges: for j in e[0] + e[1]: new_edges += 1 assert len(pstate.tables.edges) + new_edges == tcopy_num_edges_b4_simplify temp_edges = tskit.EdgeTable() temp_edges_from_before = tskit.EdgeTable() new_edges2 = 0 edges_added = 0 E = 0 # NOTE: issue now is corner case of pre-existing edge # not in edge table, but then leaving descendants later for o in reversed(pstate.generation_offsets): print("range =", *o) for i in range(*o): # Fetch the parent node IDs pnodes = pstate.pnodes[i] if i < len(pwhere): pnodes = pstate.pnodes[i] where0 = pwhere[i][1]