def test_two_populations_migration(self):
    """
    Check that completing an edge-free, two-population tree sequence via
    ``from_ts`` yields the same tables as simulating directly with the
    same seed and symmetric migration.
    """
    num_samples = 10
    rng_seed = 1234

    # Reference run: every sample starts in population 0.
    direct_ts = msprime.simulate(
        population_configurations=[
            msprime.PopulationConfiguration(num_samples),
            msprime.PopulationConfiguration(0),
        ],
        migration_matrix=[[0, 1], [1, 0]],
        random_seed=rng_seed,
    )

    # The same starting state written out as tables: two populations and
    # ``num_samples`` sample nodes at time 0 in population 0, with no edges.
    initial = msprime.TableCollection(1)
    for _ in range(2):
        initial.populations.add_row()
    for _ in range(num_samples):
        initial.nodes.add_row(
            flags=msprime.NODE_IS_SAMPLE, time=0, population=0)
    resumed_ts = msprime.simulate(
        from_ts=initial.tree_sequence(),
        start_time=0,
        population_configurations=[
            msprime.PopulationConfiguration(),
            msprime.PopulationConfiguration(),
        ],
        migration_matrix=[[0, 1], [1, 0]],
        random_seed=rng_seed,
    )

    # Compare everything except the provenance records.
    direct_tables = direct_ts.dump_tables()
    resumed_tables = resumed_ts.dump_tables()
    for collection in (direct_tables, resumed_tables):
        collection.provenances.clear()
    self.assertEqual(direct_tables, resumed_tables)
def wright_fisher(N, delta, L, T):
    """
    Direct implementation of Algorithm W.

    Starting from ``N`` nodes at time ``T``, step the time down to zero.
    At each step every one of the ``N`` slots is, with probability
    ``delta``, replaced by a new node whose genome is inherited from one
    randomly chosen parent on ``[0, x)`` and another on ``[x, L)``, where
    ``x`` is drawn uniformly from ``[0, L]``. Nodes occupying a slot at
    the end are flagged as samples. Returns the resulting tree sequence.
    """
    tables = msprime.TableCollection(L)
    # birth_time[u] is the time of node u; nodes 0..N-1 are the founders.
    birth_time = [T] * N
    current = list(range(N))
    next_id = N
    t = T
    while t > 0:
        t -= 1
        # Parents are always drawn from ``current`` (the previous step),
        # never from offspring created during this step.
        replacement = list(current)
        for slot in range(N):
            if random.random() < delta:
                replacement[slot] = next_id
                birth_time.append(t)
                left_parent = random.randint(0, N - 1)
                right_parent = random.randint(0, N - 1)
                breakpoint_pos = random.uniform(0, L)
                tables.edges.add_row(
                    0, breakpoint_pos, current[left_parent], next_id)
                tables.edges.add_row(
                    breakpoint_pos, L, current[right_parent], next_id)
                next_id += 1
        current = replacement

    alive = set(current)
    for node_id, node_time in enumerate(birth_time):
        tables.nodes.add_row(time=node_time, flags=int(node_id in alive))
    tables.sort()
    return tables.tree_sequence()
def verify_simple_model(self, n, seed=1, recombination_rate=None,
                        length=None, recombination_map=None):
    """
    Assert that simulating ``n`` samples directly and simulating from an
    edge-free tree sequence holding the same ``n`` samples produce
    identical tables (provenance excluded) under ``self.model``.

    ``length`` is only passed to the direct simulation; the ``from_ts``
    run is given initial tables created with the direct run's sequence
    length.
    """
    direct_ts = msprime.simulate(
        n,
        random_seed=seed,
        recombination_rate=recombination_rate,
        length=length,
        recombination_map=recombination_map,
        model=self.model,
    )

    # Initial state as tables: one population, n sample nodes at time 0.
    initial = msprime.TableCollection(direct_ts.sequence_length)
    initial.populations.add_row()
    for _ in range(n):
        initial.nodes.add_row(
            flags=msprime.NODE_IS_SAMPLE, time=0, population=0)
    resumed_ts = msprime.simulate(
        from_ts=initial.tree_sequence(),
        start_time=0,
        random_seed=seed,
        recombination_rate=recombination_rate,
        recombination_map=recombination_map,
        model=self.model,
    )

    # Compare everything except the provenance records.
    direct_tables = direct_ts.dump_tables()
    resumed_tables = resumed_ts.dump_tables()
    for collection in (direct_tables, resumed_tables):
        collection.provenances.clear()
    self.assertEqual(direct_tables, resumed_tables)
def run(self, ngens):
    """
    Run the forwards-time simulation for ``ngens`` generations.

    The population holds ``self.N`` slots. If ``self.deep_history`` is
    set, the starting nodes and edges are taken from a coalescent
    simulation whose node times are shifted back by ``ngens``; otherwise
    ``self.N`` nodes are created at time ``ngens``. In each generation
    every slot dies with probability ``1 - self.survival`` and is replaced
    by a new node with two parents and one breakpoint from
    ``self.random_breakpoint()``.

    :param int ngens: Number of generations to simulate.
    :return: The table collection, with the final population's nodes
        flagged as samples.
    """
    tables = msprime.TableCollection()
    if self.deep_history:
        # initial population
        init_ts = msprime.simulate(self.N, recombination_rate=1.0)
        init_tables = init_ts.dump_tables()
        tables.nodes.set_columns(
            time=init_tables.nodes.time + ngens,
            flags=init_tables.nodes.flags)
        tables.edges.set_columns(
            left=init_tables.edges.left,
            right=init_tables.edges.right,
            parent=init_tables.edges.parent,
            child=init_tables.edges.child)
    else:
        for _ in range(self.N):
            tables.nodes.add_row(time=ngens)

    pop = list(range(self.N))
    for t in range(ngens - 1, -1, -1):
        if self.debug:
            print("t:", t)
            print("pop:", pop)

        dead = [random.random() > self.survival for _ in pop]
        # sample these first so that all parents are from the previous gen
        new_parents = [
            (random.choice(pop), random.choice(pop))
            for _ in range(sum(dead))]
        parent_pairs = iter(new_parents)
        if self.debug:
            print("Replacing", sum(dead), "individuals.")
        for j in range(self.N):
            if dead[j]:
                # this is: offspring ID, lparent, rparent, breakpoint
                offspring = len(tables.nodes)
                tables.nodes.add_row(time=t)
                lparent, rparent = next(parent_pairs)
                bp = self.random_breakpoint()
                if self.debug:
                    print("--->", offspring, lparent, rparent, bp)
                pop[j] = offspring
                # A breakpoint at either end means one parent contributes
                # nothing, so no zero-length edge is recorded for it.
                if bp > 0.0:
                    tables.edges.add_row(
                        left=0.0, right=bp, parent=lparent, child=offspring)
                if bp < 1.0:
                    tables.edges.add_row(
                        left=bp, right=1.0, parent=rparent, child=offspring)

    if self.debug:
        print("Done! Final pop:")
        print(pop)

    # Build the set once: testing membership against the ``pop`` list for
    # every node would cost O(N) per node.
    extant = set(pop)
    flags = [
        (msprime.NODE_IS_SAMPLE if u in extant else 0)
        for u in range(len(tables.nodes))]
    tables.nodes.set_columns(time=tables.nodes.time, flags=flags)
    return tables
def wright_fisher(N, T, L=100, random_seed=None):
    """
    Simulate a Wright-Fisher population of N haploid individuals with L
    discrete loci for T generations, and return a tree sequence in the
    form msprime needs to finish the simulation.

    Based on Algorithm W from
    https://www.biorxiv.org/content/biorxiv/early/2018/01/16/248500.full.pdf
    """
    random.seed(random_seed)
    tables = msprime.TableCollection(L)

    # The founding generation is flagged as samples so that these nodes
    # are remembered through the later simplify() call.
    for _ in range(N):
        tables.nodes.add_row(time=T, flags=msprime.NODE_IS_SAMPLE)
    parents = np.arange(N, dtype=int)

    generation = T
    while generation > 0:
        generation -= 1
        children = parents.copy()
        for slot in range(N):
            child = tables.nodes.add_row(time=generation, flags=0)
            children[slot] = child
            # One parent supplies loci [0, x), the other supplies [x, L).
            left_parent = parents[random.randint(0, N - 1)]
            right_parent = parents[random.randint(0, N - 1)]
            crossover = random.randint(1, L - 1)
            tables.edges.add_row(0, crossover, left_parent, child)
            tables.edges.add_row(crossover, L, right_parent, child)
        parents = children

    # Table manipulations follow, so that the output has the form msprime
    # needs. The columns are fetched, edited and written back with
    # set_columns().

    # Flag the extant population as samples as well.
    sample_flags = tables.nodes.flags
    sample_flags[parents] = msprime.NODE_IS_SAMPLE
    tables.nodes.set_columns(flags=sample_flags, time=tables.nodes.time)
    tables.sort()

    # Simplify with respect to the current generation while keeping the
    # ancient nodes from the founding generation.
    tables.simplify()

    # Afterwards only the nodes at time 0 remain flagged as samples.
    flags = tables.nodes.flags
    time = tables.nodes.time
    flags[:] = 0
    flags[time == 0] = msprime.NODE_IS_SAMPLE

    # The final tables must have at least one population, and every node
    # is assigned to it.
    tables.populations.add_row()
    tables.nodes.set_columns(
        flags=flags,
        time=time,
        population=np.zeros_like(tables.nodes.population))
    return tables.tree_sequence()
def test_stick_tree(self):
    """
    Check that ``start_time``/``end_time`` restrict mutations to the
    right branch of a three-node "stick" tree.

    The tree is 0 (time 0) -> 1 (time 1) -> 2 (time 2). Mutations on the
    branch spanning times [0, 1] sit above node 0 and those on the branch
    spanning [1, 2] sit above node 1.
    """
    tables = msprime.TableCollection(1.0)
    tables.nodes.add_row(flags=msprime.NODE_IS_SAMPLE, time=0)
    tables.nodes.add_row(flags=0, time=1)
    tables.nodes.add_row(flags=0, time=2)
    tables.edges.add_row(0, 1, 1, 0)
    tables.edges.add_row(0, 1, 2, 1)
    ts = tables.tree_sequence()

    # (time-window keyword arguments, node every mutation must sit above)
    cases = [
        ({"end_time": 1}, 0),
        ({"start_time": 0, "end_time": 1}, 0),
        ({"start_time": 0.5, "end_time": 1}, 0),
        ({"start_time": 1}, 1),
        ({"start_time": 1, "end_time": 2}, 1),
        ({"start_time": 1.5, "end_time": 2}, 1),
    ]
    for window, expected_node in cases:
        with self.subTest(**window):
            tsm = msprime.mutate(ts, rate=100, random_seed=1, **window)
            self.assertGreater(tsm.num_sites, 0)
            # Inspect the mutated tree sequence: the input ``ts`` has no
            # mutations, so checking it would pass vacuously.
            self.assertTrue(
                all(mut.node == expected_node for mut in tsm.mutations()))
def to_tree_sequence(self, simplify=True):
    """
    Convert this object's edges into an msprime tree sequence.

    One node is added per proband (flagged as a sample), then one node
    per distinct parent seen while walking ``self.iter_edges(forward=False)``.
    Each (node, parent) pair becomes an edge spanning ``[0, 1)``. Node
    times come from the ``'time'`` node attribute.

    :param bool simplify: If true, call ``tables.simplify()`` after sorting.
    :return: A tuple ``(tree_sequence, msprime_id)``. ``msprime_id`` maps
        this object's identifiers to the node IDs returned by ``add_row``.
        It is not updated after ``sort()``/``simplify()``, so those IDs may
        no longer match the returned tree sequence.
    """
    tables = msprime.TableCollection(1)
    coal_depth = self.get_node_attributes('time')
    msprime_id = {}

    # Probands are added first and are flagged as samples.
    for proband in self.probands():
        u = tables.nodes.add_row(
            time=coal_depth[proband],
            flags=msprime.NODE_IS_SAMPLE,
        )
        msprime_id[proband] = u

    founders = self.founders()
    visited_nodes = set()
    # Assumes iter_edges(forward=False) yields every child after it has
    # been assigned an ID (as a proband or as an earlier parent), since
    # msprime_id[node] is read below. TODO confirm.
    for node, parent in self.iter_edges(forward=False):
        if parent not in visited_nodes:
            t = coal_depth[parent]
            # NOTE(review): the flag is computed from the child ``node``
            # although the row being added is for ``parent`` -- confirm
            # this is intended.
            f = msprime.NODE_IS_SAMPLE if node in founders else 0
            u = tables.nodes.add_row(
                time=t,
                flags=f,
            )
            msprime_id[parent] = u
            visited_nodes.add(parent)
        else:
            u = msprime_id[parent]
        tables.edges.add_row(0, 1, u, msprime_id[node])

    tables.sort()
    if simplify:
        tables.simplify()

    # NOTE(review): the relabelling below is laid out as unconditional;
    # confirm it is not meant to run only when ``simplify`` is true.
    # Unmark the initial generation as samples
    flags = tables.nodes.flags
    time = tables.nodes.time
    flags[:] = 0
    flags[time == 0] = msprime.NODE_IS_SAMPLE
    # The final tables must also have at least one population which
    # the samples are assigned to
    tables.populations.add_row()
    tables.nodes.set_columns(flags=flags, time=time, population=np.zeros_like(
        tables.nodes.population))
    return tables.tree_sequence(), msprime_id
def __init__(self, ts, sample, filter_zero_mutation_sites=True):
    """
    Prepare the state needed to simplify ``ts`` down to the nodes listed
    in ``sample``.

    Copies the population and individual tables into a fresh output
    table collection, records each sample as an output node carrying
    ancestry over the whole sequence, and builds a per-input-node list of
    ``(site position, mutation ID)`` pairs.
    """
    num_nodes = ts.num_nodes
    self.ts = ts
    self.n = len(sample)
    self.sequence_length = ts.sequence_length
    self.filter_zero_mutation_sites = filter_zero_mutation_sites
    self.num_mutations = ts.num_mutations
    self.input_sites = list(ts.sites())
    self.A_head = [None] * num_nodes
    self.A_tail = [None] * num_nodes

    self.tables = msprime.TableCollection(
        sequence_length=ts.sequence_length)
    source = ts.tables
    # We don't touch populations, so add them straight in.
    populations = source.populations
    self.tables.populations.set_columns(
        metadata=populations.metadata,
        metadata_offset=populations.metadata_offset)
    # For now we don't remove individuals that have no nodes referring to
    # them, so we can just copy the table.
    individuals = source.individuals
    self.tables.individuals.set_columns(
        flags=individuals.flags,
        location=individuals.location,
        location_offset=individuals.location_offset,
        metadata=individuals.metadata,
        metadata_offset=individuals.metadata_offset)

    self.edge_buffer = {}
    # -1 marks an input node or mutation with no output counterpart yet.
    self.node_id_map = np.zeros(num_nodes, dtype=np.int32) - 1
    self.mutation_node_map = [-1] * self.num_mutations
    self.samples = set(sample)
    for sample_id in sample:
        output_id = self.record_node(sample_id, is_sample=True)
        self.add_ancestry(sample_id, 0, self.sequence_length, output_id)

    # We keep a map of input nodes to mutations.
    self.mutation_map = [[] for _ in range(num_nodes)]
    position = source.sites.position
    mutation_sites = source.mutations.site
    mutation_nodes = source.mutations.node
    for mutation_id, (site_id, node_id) in enumerate(
            zip(mutation_sites, mutation_nodes)):
        self.mutation_map[node_id].append((position[site_id], mutation_id))
def run(self, ngens):
    """
    Run the forwards-time simulation for ``ngens`` generations, recording
    the result in standalone node and edge tables.

    The population holds ``self.N`` slots. If ``self.deep_history`` is
    set, the starting nodes and edges are dumped from a coalescent
    simulation and the node times shifted back by ``ngens``; otherwise
    ``self.N`` nodes are created at time ``ngens``. In each generation
    every slot dies with probability ``1 - self.survival`` and is replaced
    by a new node with two parents and one breakpoint from
    ``self.random_breakpoint()``.

    :param int ngens: Number of generations to simulate.
    :return: A ``TableCollection`` built from the node and edge tables
        (final population flagged as samples) plus empty migration, site,
        mutation and provenance tables.
    """
    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    migrations = msprime.MigrationTable()
    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    provenances = msprime.ProvenanceTable()
    if self.deep_history:
        # initial population
        init_ts = msprime.simulate(self.N, recombination_rate=1.0)
        init_ts.dump_tables(nodes=nodes, edges=edges)
        nodes.set_columns(time=nodes.time + ngens, flags=nodes.flags)
    else:
        for _ in range(self.N):
            nodes.add_row(time=ngens)

    pop = list(range(self.N))
    for t in range(ngens - 1, -1, -1):
        if self.debug:
            print("t:", t)
            print("pop:", pop)

        dead = [random.random() > self.survival for _ in pop]
        # sample these first so that all parents are from the previous gen
        new_parents = [
            (random.choice(pop), random.choice(pop))
            for _ in range(sum(dead))]
        parent_pairs = iter(new_parents)
        if self.debug:
            print("Replacing", sum(dead), "individuals.")
        for j in range(self.N):
            if dead[j]:
                # this is: offspring ID, lparent, rparent, breakpoint
                offspring = nodes.num_rows
                nodes.add_row(time=t)
                lparent, rparent = next(parent_pairs)
                bp = self.random_breakpoint()
                if self.debug:
                    print("--->", offspring, lparent, rparent, bp)
                pop[j] = offspring
                # A breakpoint at either end means one parent contributes
                # nothing, so no zero-length edge is recorded for it.
                if bp > 0.0:
                    edges.add_row(
                        left=0.0, right=bp, parent=lparent, child=offspring)
                if bp < 1.0:
                    edges.add_row(
                        left=bp, right=1.0, parent=rparent, child=offspring)

    if self.debug:
        print("Done! Final pop:")
        print(pop)

    # Build the set once: testing membership against the ``pop`` list for
    # every node would cost O(N) per node.
    extant = set(pop)
    flags = [
        (msprime.NODE_IS_SAMPLE if u in extant else 0)
        for u in range(nodes.num_rows)]
    nodes.set_columns(time=nodes.time, flags=flags)
    if self.debug:
        print("Done.")
        print("Nodes:")
        print(nodes)
        print("Edges:")
        print(edges)
    return msprime.TableCollection(
        nodes, edges, migrations, sites, mutations, provenances)
import msprime import numpy as np tables = msprime.TableCollection(1.0) tables.edges.add_row(left=0, right=1, parent=4, child=0) tables.edges.add_row(left=0, right=1, parent=4, child=1) tables.edges.add_row(left=0, right=1, parent=5, child=2) tables.edges.add_row(left=0, right=1, parent=5, child=3) tables.edges.add_row(left=0, right=1, parent=6, child=4) tables.edges.add_row(left=0, right=1, parent=6, child=5) time = np.array([0, 0, 0, 0, 1, 4, 10]) flags = [1] * len(time) tables.nodes.set_columns(time=time, flags=flags) tables.sites.set_columns(position=np.array([0.1, 0.2, 0.3, 0.4]), ancestral_state=['0'] * 4, ancestral_state_offset=[0] * 5) tables.mutations.add_row(site=0, node=4, derived_state='1') tables.mutations.add_row(site=1, node=5, derived_state='1') tables.mutations.add_row(site=2, node=0, derived_state='1') tables.mutations.add_row(site=3, node=5, derived_state='1') ts = tables.tree_sequence() t = next(ts.trees()) fig = t.draw(path="tree.svg", mutation_labels={
def simplify(S, Ni, Ei, L):
    """
    This is an implementation of the simplify algorithm described in
    Appendix A of the paper.

    :param S: Iterable of input node IDs to use as samples.
    :param Ni: Input node table; ``len(Ni)`` and ``Ni.time[u]`` are used.
    :param Ei: Input edge table; iterated for rows exposing ``left``,
        ``right``, ``parent`` and ``child``.
    :param L: Sequence length.
    :return: The simplified tree sequence.

    Input nodes are processed in index order and a parent reads the
    ancestry already stored for its children, so a non-sample child only
    contributes if its index is lower than its parent's.
    """
    tables = msprime.TableCollection(L)
    No = tables.nodes
    Eo = tables.edges
    # A[u] lists the ancestral segments (with output node IDs) carried by
    # input node u; Q is a heap of segments awaiting merging.
    A = [[] for _ in range(len(Ni))]
    Q = []

    for u in S:
        v = No.add_row(time=Ni.time[u], flags=1)
        A[u] = [Segment(0, L, v)]

    for u in range(len(Ni)):
        # Queue the parts of each child's ancestry inherited through u.
        for e in [e for e in Ei if e.parent == u]:
            for x in A[e.child]:
                if x.right > e.left and e.right > x.left:
                    y = Segment(
                        max(x.left, e.left), min(x.right, e.right), x.node)
                    heapq.heappush(Q, y)

        # v is the output ID for u, allocated lazily on first coalescence.
        v = -1
        while len(Q) > 0:
            left = Q[0].left
            right = L
            X = []
            # Collect every queued segment starting at ``left``; ``right``
            # ends up as the nearest point where the overlap changes.
            while len(Q) > 0 and Q[0].left == left:
                x = heapq.heappop(Q)
                X.append(x)
                right = min(right, x.right)
            if len(Q) > 0:
                right = min(right, Q[0].left)

            if len(X) == 1:
                # No coalescence here: pass the segment through, cut at
                # the start of the next queued segment if they overlap.
                x = X[0]
                alpha = x
                if len(Q) > 0 and Q[0].left < x.right:
                    alpha = Segment(x.left, Q[0].left, x.node)
                    x.left = Q[0].left
                    heapq.heappush(Q, x)
            else:
                # Coalescence: record edges to u's output node and requeue
                # whatever extends past ``right``.
                if v == -1:
                    v = No.add_row(time=Ni.time[u])
                alpha = Segment(left, right, v)
                for x in X:
                    Eo.add_row(left, right, v, x.node)
                    if x.right > right:
                        x.left = right
                        heapq.heappush(Q, x)

            A[u].append(alpha)

    # Sort the output edges and compact them as much as possible into
    # the output table. We skip this for the algorithm listing as it's
    # pretty mundane.
    # Note: could be replaced with calls to squash_edges() and sort_tables()
    E = list(Eo)
    Eo.clear()
    E.sort(key=lambda e: (e.parent, e.child, e.right, e.left))
    # With no output edges (e.g. a single sample) there is nothing to
    # squash, and indexing E below would raise IndexError.
    if E:
        start = 0
        for j in range(1, len(E)):
            prev = E[j - 1]
            cur = E[j]
            condition = (
                prev.right != cur.left or
                prev.parent != cur.parent or
                prev.child != cur.child)
            if condition:
                Eo.add_row(E[start].left, prev.right, prev.parent, prev.child)
                start = j
        last = E[-1]
        Eo.add_row(E[start].left, last.right, last.parent, last.child)

    return tables.tree_sequence()
def __init__(self, node_ids=None, nodes=None, edges=None, sites=None,
             mutations=None, migrations=None, ts=None, time=0.0,
             sequence_length=None, timings=None):
    """
    The tables passed in define history before the simulation begins.
    If these are missing, then the input IDs specified in ``node_ids``
    must be ``0...n-1``.

    NOTE(review): in this constructor only ``edges`` is read (to infer the
    sequence length); ``nodes``, ``sites``, ``mutations`` and
    ``migrations`` are not copied into the recorder's tables. Confirm
    whether that is intended.

    :param dict node_ids: A dict indexed by input IDs so that
        ``node_ids[k]`` is the node ID of the node corresponding to sample
        ``k`` in the initial ``ts``. Must specify this for every
        individual that may be a parent moving forward.
    :param NodeTable nodes: A table describing prehistory of the
        simulation.
    :param EdgeTable edges: A table describing prehistory of the
        simulation.
    :param SiteTable sites: A table describing prehistory of the
        simulation.
    :param MutationTable mutations: A table describing prehistory of the
        simulation.
    :param MigrationTable migrations: A table describing prehistory of
        the simulation.
    :param TreeSequence ts: An alternative method to specifying past
        history.
    :param float time: The (forwards) time of the "present" at the start
        of the simulation.
    :param float sequence_length: The total length of the sequence
        (derived from input if not provided).
    :param ftprime.benchmarker.Timings timings: An object to record
        timing information.
    :raises ValueError: If ``sequence_length`` disagrees with
        ``ts.sequence_length``, or if it is omitted and cannot be derived
        from ``ts`` or a non-empty ``edges`` table.
    """
    self.timings = timings
    if timings is not None:
        start = timer.process_time()

    # this is the largest (forwards) time seen so far
    self.max_time = time  # T
    # dict of output node IDs indexed by input labels
    self.node_ids = {} if node_ids is None else dict(node_ids)

    # the actual tables that get updated
    # DON'T actually store ts, just the tables:
    if ts is None:
        tables = msprime.TableCollection()
        # Without a tree sequence, the sorted input IDs must map onto
        # node IDs 0..n-1 in order; one node row is created for each.
        for j, k in enumerate(sorted(self.node_ids.keys())):
            assert j == self.node_ids[k]
            tables.nodes.add_row(
                population=msprime.NULL_POPULATION, time=time)
    else:
        tables = ts.dump_tables()
    self.table_collection = tables
    self.nodes = tables.nodes
    self.edges = tables.edges
    self.sites = tables.sites
    self.mutations = tables.mutations
    self.migrations = tables.migrations

    # Sequence length: an explicit value wins (but must agree with ts),
    # then ts, then the rightmost coordinate of the supplied edges.
    if sequence_length is not None:
        if ts is not None and sequence_length != ts.sequence_length:
            raise ValueError(
                "Provided sequence_length does not match "
                "that of tree sequence ts.")
        self.sequence_length = sequence_length
    elif ts is not None:
        self.sequence_length = ts.sequence_length
    elif edges is not None and edges.num_rows > 0:
        # ``edges`` defaults to None, so it is checked before use.
        self.sequence_length = max(edges.right)
    else:
        raise ValueError(
            "If prior history is not specified, sequence "
            "length must be provided.")

    # last (forwards) time we updated node times
    self.last_update_time = time  # T_0
    # number of nodes that have the time right
    self.last_update_node = self.nodes.num_rows
    # list of site positions, maintained as site tables don't have
    # efficient checking for membership
    self.site_positions = {
        p: k for k, p in enumerate(self.sites.position)}
    # for bookkeeping
    self.num_simplifies = 0

    if self.timings is not None:
        self.timings.time_prepping += timer.process_time() - start