def __init__(self, ts, sample, filter_zero_mutation_sites=True):
    """
    Set up the simplifier state for the input tree sequence ``ts``.

    :param ts: The input tree sequence.
    :param sample: Sequence of input node IDs to keep as samples; each is
        recorded as an output sample node and given ancestry over the
        whole sequence.
    :param filter_zero_mutation_sites: Stored on the instance for use by
        other methods.
    """
    self.ts = ts
    self.n = len(sample)
    self.sequence_length = ts.sequence_length
    self.filter_zero_mutation_sites = filter_zero_mutation_sites
    self.num_mutations = ts.num_mutations
    self.input_sites = list(ts.sites())
    # Per-input-node head/tail entries, presumably of the ancestry segment
    # chains maintained by add_ancestry (not visible here).
    self.A_head = [None for _ in range(ts.num_nodes)]
    self.A_tail = [None for _ in range(ts.num_nodes)]
    # Output tables, pre-sized from the input tree sequence. Each table is
    # allocated exactly once.
    self.node_table = msprime.NodeTable(ts.num_nodes)
    self.edge_table = msprime.EdgeTable(ts.num_edges)
    self.site_table = msprime.SiteTable(ts.num_sites)
    self.mutation_table = msprime.MutationTable(ts.num_mutations)
    self.edge_buffer = {}
    # Input node ID -> output node ID; -1 means "not yet mapped".
    self.node_id_map = np.full(ts.num_nodes, -1, dtype=np.int32)
    self.mutation_node_map = [-1 for _ in range(self.num_mutations)]
    self.samples = set(sample)
    for sample_id in sample:
        output_id = self.record_node(sample_id, is_sample=True)
        self.add_ancestry(sample_id, 0, self.sequence_length, output_id)
    # We keep a map of input nodes to mutations.
    self.mutation_map = [[] for _ in range(ts.num_nodes)]
    position = ts.tables.sites.position
    site = ts.tables.mutations.site
    node = ts.tables.mutations.node
    for mutation_id in range(ts.num_mutations):
        site_position = position[site[mutation_id]]
        self.mutation_map[node[mutation_id]].append(
            (site_position, mutation_id))
def __init__(self, ts, sample, filter_zero_mutation_sites=True):
    """
    Set up the simplifier state for the input tree sequence ``ts``.

    :param ts: The input tree sequence.
    :param sample: Sequence of input node IDs to keep as samples; each is
        passed to ``insert_sample``.
    :param filter_zero_mutation_sites: Stored on the instance for use by
        other methods.
    """
    self.ts = ts
    self.n = len(sample)
    self.sequence_length = ts.sequence_length
    self.filter_zero_mutation_sites = filter_zero_mutation_sites
    self.num_mutations = ts.num_mutations
    self.input_sites = list(ts.sites())
    # A maps input node IDs to the extant ancestor chain. Once the algorithm
    # has processed the ancestors, they are removed from the map.
    self.A = {}
    # Output tables, pre-sized from the input tree sequence. Each table is
    # allocated exactly once.
    self.node_table = msprime.NodeTable(ts.num_nodes)
    self.edge_table = msprime.EdgeTable(ts.num_edges)
    self.site_table = msprime.SiteTable(ts.num_sites)
    self.mutation_table = msprime.MutationTable(ts.num_mutations)
    self.edge_buffer = []
    self.node_id_map = {}
    self.mutation_node_map = [-1 for _ in range(self.num_mutations)]
    self.samples = set(sample)
    for sample_id in sample:
        self.insert_sample(sample_id)
    # We keep a map of input nodes to mutations.
    self.mutation_map = [[] for _ in range(ts.num_nodes)]
    position = ts.tables.sites.position
    site = ts.tables.mutations.site
    node = ts.tables.mutations.node
    for mutation_id in range(ts.num_mutations):
        site_position = position[site[mutation_id]]
        self.mutation_map[node[mutation_id]].append(
            (site_position, mutation_id))
def wright_fisher(N, delta, L, T):
    """
    Direct implementation of Algorithm W.

    Simulates ``T`` generations of a population of ``N`` individuals. In each
    generation every individual is replaced with probability ``delta`` by a
    new node. The new node inherits ``[0, x)`` from one uniformly chosen
    parent and ``[x, L)`` from another uniformly chosen (possibly identical)
    parent, with ``x`` drawn by ``random.uniform(0, L)``. Nodes alive after
    the last generation are flagged as samples.

    :return: The tree sequence loaded from the sorted node and edge tables.
    """
    edges = msprime.EdgeTable()
    # Birth time of every node, indexed by node ID; founders are born at T.
    birth_time = [T] * N
    alive = list(range(N))
    next_id = N
    for t in range(T - 1, -1, -1):
        next_alive = list(alive)
        for slot in range(N):
            if not random.random() < delta:
                continue
            next_alive[slot] = next_id
            birth_time.append(t)
            a = random.randint(0, N - 1)
            b = random.randint(0, N - 1)
            x = random.uniform(0, L)
            # Parents are always taken from the previous generation.
            edges.add_row(0, x, alive[a], next_id)
            edges.add_row(x, L, alive[b], next_id)
            next_id += 1
        alive = next_alive
    final = set(alive)
    nodes = msprime.NodeTable()
    for node_id, birth in enumerate(birth_time):
        nodes.add_row(time=birth, flags=int(node_id in final))
    msprime.sort_tables(nodes=nodes, edges=edges)
    return msprime.load_tables(nodes=nodes, edges=edges)
def _load_legacy_hdf5_v3(root, remove_duplicate_positions):
    """
    Convert a version 3 legacy HDF5 file into a tree sequence.

    Version 3 stores coalescence records, each with an interval (given as
    indexes into ``breakpoints``), a parent node and a number of children.
    Each record is expanded into one edge per child.

    :param root: The root group of the legacy HDF5 file.
    :param remove_duplicate_positions: Passed to ``_convert_hdf5_mutations``
        when the file contains mutations.
    :return: The converted tree sequence, with sorted tables and an upgrade
        provenance record appended.
    """
    # get the trees group for the records and samples
    trees_group = root["trees"]
    nodes_group = trees_group["nodes"]
    time = np.array(nodes_group["time"])
    breakpoints = np.array(trees_group["breakpoints"])
    records_group = trees_group["records"]
    left_indexes = np.array(records_group["left"])
    right_indexes = np.array(records_group["right"])
    record_node = np.array(records_group["node"], dtype=np.int32)
    num_nodes = time.shape[0]
    # Nodes with IDs below the smallest parent ID are flagged as samples.
    sample_size = np.min(record_node)
    flags = np.zeros(num_nodes, dtype=np.uint32)
    flags[:sample_size] = msprime.NODE_IS_SAMPLE

    # Expand each record into one edge per child. np.repeat does this in a
    # single vectorised step rather than a Python loop over every edge.
    children_length = np.array(records_group["num_children"], dtype=np.intp)
    left = np.repeat(
        breakpoints[left_indexes], children_length).astype(np.float64)
    right = np.repeat(
        breakpoints[right_indexes], children_length).astype(np.float64)
    parent = np.repeat(record_node, children_length)

    nodes = msprime.NodeTable()
    nodes.set_columns(flags=flags, time=nodes_group["time"],
                      population=nodes_group["population"])
    edges = msprime.EdgeTable()
    edges.set_columns(left=left, right=right, parent=parent,
                      child=records_group["children"])
    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    if "mutations" in root:
        _convert_hdf5_mutations(root["mutations"], sites, mutations,
                                remove_duplicate_positions)
    old_timestamp = datetime.datetime.min.isoformat()
    provenances = msprime.ProvenanceTable()
    if "provenance" in root:
        for record in root["provenance"]:
            provenances.add_row(timestamp=old_timestamp, record=record)
    provenances.add_row(_get_upgrade_provenance(root))
    msprime.sort_tables(nodes=nodes, edges=edges, sites=sites,
                        mutations=mutations)
    return msprime.load_tables(nodes=nodes, edges=edges, sites=sites,
                               mutations=mutations, provenances=provenances)
def store_output(self):
    """
    Build the output tree sequence and optionally write it to disk.

    When at least one ancestor exists, the result comes from
    ``self.get_tree_sequence(rescale_positions=False)``. Otherwise an empty
    tree sequence of length 1 is returned. If ``self.output_path`` is set,
    the result is dumped there.

    :return: The output tree sequence.
    """
    have_ancestors = self.num_ancestors > 0
    if not have_ancestors:
        # Allocate an empty tree sequence.
        ts = msprime.load_tables(
            nodes=msprime.NodeTable(), edges=msprime.EdgeTable(),
            sequence_length=1)
    else:
        ts = self.get_tree_sequence(rescale_positions=False)
    if self.output_path is not None:
        ts.dump(self.output_path)
    return ts
def node_metadata_example():
    """
    Return a simulated tree sequence whose nodes carry string metadata.

    Node ``u`` is given the metadata string ``"n_<u>"`` (packed with
    ``msprime.pack_strings``). Node flags, node times and edges come from a
    seeded simulation of 100 samples with recombination rate 0.1 over
    length 10.
    """
    source = msprime.simulate(
        sample_size=100, recombination_rate=0.1, length=10, random_seed=1)
    old_nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    source.dump_tables(nodes=old_nodes, edges=edges)
    labels = ["n_{}".format(node_id) for node_id in range(source.num_nodes)]
    packed, offset = msprime.pack_strings(labels)
    labelled_nodes = msprime.NodeTable()
    labelled_nodes.set_columns(
        flags=old_nodes.flags, time=old_nodes.time,
        metadata=packed, metadata_offset=offset)
    return msprime.load_tables(nodes=labelled_nodes, edges=edges)
def general_mutation_example():
    """
    Return a small simulated tree sequence with two hand-made sites.

    Sites are at positions 0 and 1 with ancestral states "A" and "C". Each
    has one mutation on node 0, with derived states "T" and "G"
    respectively. Site metadata is stored as raw bytes.
    """
    ts = msprime.simulate(10, recombination_rate=1, length=10, random_seed=2)
    nodes, edges = msprime.NodeTable(), msprime.EdgeTable()
    ts.dump_tables(nodes=nodes, edges=edges)
    sites, mutations = msprime.SiteTable(), msprime.MutationTable()
    site_specs = [(0, "A", b"{}"), (1, "C", b"{'id':1}")]
    for position, ancestral, meta in site_specs:
        sites.add_row(position=position, ancestral_state=ancestral,
                      metadata=meta)
    for site_id, derived in enumerate(["T", "G"]):
        mutations.add_row(site=site_id, node=0, derived_state=derived)
    return msprime.load_tables(
        nodes=nodes, edges=edges, sites=sites, mutations=mutations)
def test_nodes(self):
    """Pickled node metadata survives a round trip through load_tables."""
    expected = ExampleMetadata(one="node1", two="node2")
    blob = pickle.dumps(expected)
    nodes, edges = msprime.NodeTable(), msprime.EdgeTable()
    nodes.add_row(time=0.125, metadata=blob)
    ts = msprime.load_tables(nodes=nodes, edges=edges, sequence_length=1)
    first = ts.node(0)
    self.assertEqual(first.time, 0.125)
    self.assertEqual(first.metadata, blob)
    restored = pickle.loads(first.metadata)
    for attr in ("one", "two"):
        self.assertEqual(getattr(restored, attr), getattr(expected, attr))
def test_nodes(self):
    """JSON-schema node metadata survives a round trip through load_tables."""
    classes = pjs.ObjectBuilder(json.loads(self.schema)).build_classes()
    expected = classes.ExampleMetadata(one="node1", two="node2")
    payload = json.dumps(expected.as_dict()).encode()
    nodes, edges = msprime.NodeTable(), msprime.EdgeTable()
    nodes.add_row(time=0.125, metadata=payload)
    ts = msprime.load_tables(nodes=nodes, edges=edges, sequence_length=1)
    first = ts.node(0)
    self.assertEqual(first.time, 0.125)
    self.assertEqual(first.metadata, payload)
    restored = classes.ExampleMetadata.from_json(first.metadata.decode())
    for attr in ("one", "two"):
        self.assertEqual(getattr(restored, attr), getattr(expected, attr))
def __init__(self, gc_interval, trees=None):
    """
    :param gc_interval: Garbage collection interval
    :param trees: An instance of :class:`msprime.TreeSequence`
    """
    self.gc_interval = gc_interval
    self.last_gc_time = 0.0
    self.__nodes = msprime.NodeTable()
    self.__edges = msprime.EdgeTable()
    # __process is True only when no input tree sequence was given; otherwise
    # the input's node and edge tables are copied in.
    self.__process = trees is None
    if not self.__process:
        trees.dump_tables(nodes=self.__nodes, edges=self.__edges)
    # Accumulated timing counters, all starting at zero.
    self.__time_sorting = 0.0
    self.__time_appending = 0.0
    self.__time_simplifying = 0.0
    self.__time_prepping = 0.0
def test_sites(self):
    """Pickled site metadata survives a round trip through load_tables."""
    expected = ExampleMetadata(one="node1", two="node2")
    blob = pickle.dumps(expected)
    nodes, edges = msprime.NodeTable(), msprime.EdgeTable()
    sites, mutations = msprime.SiteTable(), msprime.MutationTable()
    sites.add_row(position=0.1, ancestral_state="A", metadata=blob)
    ts = msprime.load_tables(
        nodes=nodes, edges=edges, sites=sites, mutations=mutations,
        sequence_length=1)
    first = ts.site(0)
    self.assertEqual(first.position, 0.1)
    self.assertEqual(first.ancestral_state, "A")
    self.assertEqual(first.metadata, blob)
    restored = pickle.loads(first.metadata)
    for attr in ("one", "two"):
        self.assertEqual(getattr(restored, attr), getattr(expected, attr))
def permute_nodes(ts, node_map):
    """
    Returns a copy of the specified tree sequence such that the nodes are
    permuted according to the specified map.

    :param ts: The input tree sequence.
    :param node_map: Sequence where ``node_map[u]`` is the new ID of input
        node ``u``; it must be a permutation of ``range(ts.num_nodes)``.
    :return: The permuted tree sequence. Node, site and mutation metadata
        are preserved, and a "permute_nodes" provenance record is appended.
    """
    # Mapping from nodes in the new tree sequence back to nodes in the original
    reverse_map = [0 for _ in node_map]
    for j in range(ts.num_nodes):
        reverse_map[node_map[j]] = j
    old_nodes = list(ts.nodes())
    new_nodes = msprime.NodeTable()
    for j in range(ts.num_nodes):
        old_node = old_nodes[reverse_map[j]]
        new_nodes.add_row(flags=old_node.flags, metadata=old_node.metadata,
                          population=old_node.population, time=old_node.time)
    new_edges = msprime.EdgeTable()
    for edge in ts.edges():
        new_edges.add_row(left=edge.left, right=edge.right,
                          parent=node_map[edge.parent],
                          child=node_map[edge.child])
    new_sites = msprime.SiteTable()
    new_mutations = msprime.MutationTable()
    for site in ts.sites():
        # Site and mutation metadata are carried across, just as node
        # metadata is above, so the result is a faithful copy.
        new_sites.add_row(position=site.position,
                          ancestral_state=site.ancestral_state,
                          metadata=site.metadata)
        for mutation in site.mutations:
            new_mutations.add_row(site=site.id,
                                  derived_state=mutation.derived_state,
                                  node=node_map[mutation.node],
                                  metadata=mutation.metadata)
    msprime.sort_tables(nodes=new_nodes, edges=new_edges, sites=new_sites,
                        mutations=new_mutations)
    provenances = ts.dump_tables().provenances
    add_provenance(provenances, "permute_nodes")
    return msprime.load_tables(nodes=new_nodes, edges=new_edges,
                               sites=new_sites, mutations=new_mutations,
                               provenances=provenances)
def test_mutations(self):
    """Pickled mutation metadata survives a round trip through load_tables."""
    expected = ExampleMetadata(one="node1", two="node2")
    blob = pickle.dumps(expected)
    nodes, edges = msprime.NodeTable(), msprime.EdgeTable()
    sites, mutations = msprime.SiteTable(), msprime.MutationTable()
    nodes.add_row(time=0)
    sites.add_row(position=0.1, ancestral_state="A")
    mutations.add_row(site=0, node=0, derived_state="T", metadata=blob)
    ts = msprime.load_tables(
        nodes=nodes, edges=edges, sites=sites, mutations=mutations,
        sequence_length=1)
    first = ts.site(0).mutations[0]
    checks = (("site", 0), ("node", 0), ("derived_state", "T"),
              ("metadata", blob))
    for attr, value in checks:
        self.assertEqual(getattr(first, attr), value)
    restored = pickle.loads(first.metadata)
    for attr in ("one", "two"):
        self.assertEqual(getattr(restored, attr), getattr(expected, attr))
def wright_fisher(N, T, simplify_interval=1):
    """
    An implementation of algorithm W where we simplify after every
    generation. The goal here is to measure the number of edges in the
    tree sequence representing the history as a function of time.

    For simplicity we assume that the genome length L = 1 and the probability
    of death delta = 1.

    :return: A ``(tree_sequence, S)`` pair, where ``S[g]`` is the number of
        edges after generation ``g`` (and after any simplification in it).
    """
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    parents = list(range(N))
    for _ in range(N):
        nodes.add_row(time=T, flags=1)
    S = np.zeros(T, dtype=int)
    for t in reversed(range(T)):
        children = []
        for _ in range(N):
            child = len(nodes)
            nodes.add_row(time=t, flags=1)
            children.append(child)
            a = random.randint(0, N - 1)
            b = random.randint(0, N - 1)
            x = random.uniform(0, L)
            edges.add_row(0, x, parents[a], child)
            edges.add_row(x, L, parents[b], child)
        parents = children
        if t % simplify_interval == 0:
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(children, nodes, edges)
            # After simplification the current generation is nodes 0..N-1.
            parents = list(range(N))
        S[T - t - 1] = len(edges)
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges), S
def get_empty_tree(self):
    """Return the first tree of a length-1 tree sequence with no nodes or edges."""
    empty_ts = msprime.load_tables(
        nodes=msprime.NodeTable(), edges=msprime.EdgeTable(),
        sequence_length=1)
    tree_iter = empty_ts.trees()
    return next(tree_iter)
def _load_legacy_hdf5_v10(root, remove_duplicate_positions=False):
    """
    Convert a version 10 legacy HDF5 file into a tree sequence.

    ``remove_duplicate_positions`` is accepted for interface compatibility
    with the other legacy loaders. It is ignored because duplicate site
    positions cannot occur in v10.

    :return: The converted tree sequence, with an upgrade provenance record
        appended.
    """

    def optional_metadata(group):
        # (metadata, metadata_offset) stored in ``group``, or (None, None)
        # when the group has no metadata column.
        if "metadata" in group:
            return group["metadata"], group["metadata_offset"]
        return None, None

    nodes_group = root["nodes"]
    node_meta, node_meta_offset = optional_metadata(nodes_group)
    nodes = msprime.NodeTable()
    nodes.set_columns(
        flags=nodes_group["flags"], population=nodes_group["population"],
        time=nodes_group["time"], metadata=node_meta,
        metadata_offset=node_meta_offset)

    edges_group = root["edges"]
    edges = msprime.EdgeTable()
    edges.set_columns(
        left=edges_group["left"], right=edges_group["right"],
        parent=edges_group["parent"], child=edges_group["child"])

    migrations_group = root["migrations"]
    migrations = msprime.MigrationTable()
    if "left" in migrations_group:
        migrations.set_columns(
            left=migrations_group["left"], right=migrations_group["right"],
            node=migrations_group["node"], source=migrations_group["source"],
            dest=migrations_group["dest"], time=migrations_group["time"])

    sites_group = root["sites"]
    sites = msprime.SiteTable()
    if "position" in sites_group:
        site_meta, site_meta_offset = optional_metadata(sites_group)
        sites.set_columns(
            position=sites_group["position"],
            ancestral_state=sites_group["ancestral_state"],
            ancestral_state_offset=sites_group["ancestral_state_offset"],
            metadata=site_meta, metadata_offset=site_meta_offset)

    mutations_group = root["mutations"]
    mutations = msprime.MutationTable()
    if "site" in mutations_group:
        mut_meta, mut_meta_offset = optional_metadata(mutations_group)
        mutations.set_columns(
            site=mutations_group["site"], node=mutations_group["node"],
            parent=mutations_group["parent"],
            derived_state=mutations_group["derived_state"],
            derived_state_offset=mutations_group["derived_state_offset"],
            metadata=mut_meta, metadata_offset=mut_meta_offset)

    provenances_group = root["provenances"]
    provenances = msprime.ProvenanceTable()
    if "timestamp" in provenances_group:
        timestamp = provenances_group["timestamp"]
        timestamp_offset = provenances_group["timestamp_offset"]
        if "record" not in provenances_group:
            # No records stored: every row gets an empty record.
            record = np.empty_like(timestamp)
            record_offset = np.zeros_like(timestamp_offset)
        else:
            record = provenances_group["record"]
            record_offset = provenances_group["record_offset"]
        provenances.set_columns(
            timestamp=timestamp, timestamp_offset=timestamp_offset,
            record=record, record_offset=record_offset)
    provenances.add_row(_get_upgrade_provenance(root))

    return msprime.load_tables(
        nodes=nodes, edges=edges, migrations=migrations, sites=sites,
        mutations=mutations, provenances=provenances)
# NOTE(review): this is a fragment of a larger function whose header is not
# visible; popsize, nodes, edges and ngens are defined before this point.
expensive_check(popsize, edges, nodes)
max_gen = nodes['generation'].max()
assert (int(max_gen) == 20 * popsize)
# Convert node times from forwards to backwards
nodes['generation'] = nodes['generation'] - max_gen
nodes['generation'] = nodes['generation'] * -1.0
# Construct and populate msprime's tables
# Every forward-simulated node gets flags == 1.
flags = np.empty([len(nodes)], dtype=np.uint32)
flags.fill(1)
# Coalescent history for 2 * popsize samples, presumably used as the
# ancestry of the forward simulation's founders -- confirm with the caller.
prior_ts = msprime.simulate(2 * popsize)
nt = msprime.NodeTable()
es = msprime.EdgeTable()
prior_ts.dump_tables(nodes=nt, edges=es)
# Shift the coalescent nodes back by ngens + 1 time units.
nt.set_columns(
    flags=nt.flags,  # [2 * popsize:],
    population=nt.population,  # [2 * popsize:],
    time=nt.time + ngens + 1)
# Forward-simulation node IDs are shifted past the coalescent nodes.
node_offset = nt.num_rows
nt.append_columns(flags=flags,
                  # NOTE(review): node_offset is a node-ID shift; adding it to
                  # population IDs looks like a slip -- confirm whether this
                  # should simply be nodes['population'].
                  population=nodes['population'] + node_offset,
                  time=nodes['generation'])
es.append_columns(left=edges['left'],
                  right=edges['right'],
                  parent=edges['parent'] + node_offset,
                  child=edges['child'] + node_offset)
def wfrec(nsam, rho, nsites, theta):
    """
    Simulate a coalescent history with recombination and return it as a
    tree sequence.

    Each lineage is an ``IntervalTree`` of the ancestral material it carries
    on ``[0, nsites)``. With ``n`` lineages, coalescence has rate
    ``n * (n - 1)`` and recombination has rate ``rho / (nsites - 1)`` per
    link. Waiting times are exponential with mean ``4 / (total rate)``.

    :param nsam: Number of sample nodes (IDs ``0 .. nsam - 1``, time 0).
    :param rho: Recombination rate over the region; divided by
        ``nsites - 1`` to get the per-link rate.
    :param nsites: Number of sites; every sample starts with ``[0, nsites)``.
    :param theta: Not used by this function.
    :return: The sorted node and edge tables loaded as a tree sequence.
    """
    samples = [it.IntervalTree([it.Interval(0, nsites)])
               for _ in range(nsam)]
    # np.int was removed in NumPy 1.24; the builtin int is its replacement.
    links = np.array([sumIntervalTree(i) for i in samples], dtype=int)
    nlinks = links.sum()
    n = nsam
    rbp = rho / float(nsites - 1)
    t = 0.0
    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    nodes.set_columns(time=np.zeros(nsam),
                      flags=np.ones(nsam, dtype=np.uint32))
    # sample_indexes[k] is the node ID whose material lineage k carries.
    sample_indexes = list(range(len(samples)))
    next_index = len(sample_indexes)
    while n > 1:
        rcoal = float(n * (n - 1))
        rrec = rbp * float(nlinks)
        iscoal = bool(np.random.random_sample(1)[0] < rcoal / (rcoal + rrec))
        t += np.random.exponential(4. / (rcoal + rrec), 1)[0]
        assert len(samples) == len(links), "sample/link error"
        if iscoal:
            c1, c2 = np.sort(np.random.choice(n, 2, replace=False))
            # NOTE(review): internal nodes are flagged as samples here;
            # confirm this is intended.
            nodes.add_row(time=t, flags=msprime.NODE_IS_SAMPLE)
            # Each child contributes only the intervals it actually carries.
            for i in samples[c1]:
                edges.add_row(left=i[0], right=i[1], parent=next_index,
                              child=sample_indexes[c1])
            for i in samples[c2]:
                edges.add_row(left=i[0], right=i[1], parent=next_index,
                              child=sample_indexes[c2])
            newchrom = it.IntervalTree()
            # Merge intervals of the two chromosomes
            # and remove overlaps
            for i in samples[c1]:
                newchrom.append(i)
            for i in samples[c2]:
                newchrom.append(i)
            newchrom.merge_overlaps()
            samples.pop(c2)
            samples.pop(c1)
            samples.append(newchrom)
            sample_indexes.pop(c2)
            sample_indexes.pop(c1)
            sample_indexes.append(next_index)
            next_index += 1
            n -= 1
        else:
            # Pick a chrom proportional to
            # its total size:
            chrom = np.random.choice(len(sample_indexes), 1,
                                     p=links / links.sum())[0]
            mnpos = min(
                [i for j in samples[chrom] for i in j if i is not None])
            mxpos = max(
                [i for j in samples[chrom] for i in j if i is not None])
            # The breakpoint must lie strictly inside the span; pos == mnpos
            # would move all material to the new piece and leave an empty
            # lineage behind.
            pos = np.random.randint(mnpos + 1, mxpos)
            samples[chrom].chop(pos, pos)
            tc = it.IntervalTree([i for i in samples[chrom] if i[0] >= pos])
            samples[chrom].remove_overlap(pos, nsites)
            samples.append(tc)
            # Recombination creates no node: the split-off piece is still
            # material of the same node, so it keeps that node's ID.
            # (Allocating a fresh ID here made parent IDs run past the node
            # table.)
            sample_indexes.append(sample_indexes[chrom])
            n += 1
        assert all([len(i) > 0 for i in samples]), "empty IntervalTree"
        assert len(samples) == len(sample_indexes), \
            "sample/sample_index error"
        links = np.array([sumIntervalTree(i) for i in samples], dtype=int)
        nlinks = links.sum()
        assert len(samples) == len(links), "sample/link error 2"
    for i in range(len(edges)):
        assert edges[i].parent < len(nodes), "parent error"
        assert edges[i].child < len(nodes), "child error"
    msprime.sort_tables(nodes=nodes, edges=edges)
    return msprime.load_tables(nodes=nodes, edges=edges)
def run(self, ngens):
    """
    Run the forwards-time simulation for ``ngens`` generations.

    In each generation, each individual is replaced when
    ``random.random() > self.survival``. The replacement is an offspring of
    two parents drawn with replacement from the previous generation, split
    at ``self.random_breakpoint()``. Individuals alive at the end are
    flagged as samples.

    :param ngens: Number of generations. Initial nodes get time ``ngens``
        (shifted by ``ngens`` when ``self.deep_history`` is set) and the
        final generation has time 0.
    :return: A ``msprime.TableCollection`` with the node and edge tables
        filled in; the migration, site, mutation and provenance tables are
        empty.
    """
    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    migrations = msprime.MigrationTable()
    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    provenances = msprime.ProvenanceTable()
    if self.deep_history:
        # initial population
        init_ts = msprime.simulate(self.N, recombination_rate=1.0)
        init_ts.dump_tables(nodes=nodes, edges=edges)
        nodes.set_columns(time=nodes.time + ngens, flags=nodes.flags)
    else:
        for _ in range(self.N):
            nodes.add_row(time=ngens)

    pop = list(range(self.N))
    for t in range(ngens - 1, -1, -1):
        if self.debug:
            print("t:", t)
            print("pop:", pop)

        dead = [random.random() > self.survival for _ in pop]
        # sample these first so that all parents are from the previous gen
        new_parents = [(random.choice(pop), random.choice(pop))
                       for _ in range(sum(dead))]
        k = 0
        if self.debug:
            print("Replacing", sum(dead), "individuals.")
        for j in range(self.N):
            if dead[j]:
                # this is: offspring ID, lparent, rparent, breakpoint
                offspring = nodes.num_rows
                nodes.add_row(time=t)
                lparent, rparent = new_parents[k]
                k += 1
                bp = self.random_breakpoint()
                if self.debug:
                    print("--->", offspring, lparent, rparent, bp)
                pop[j] = offspring
                if bp > 0.0:
                    edges.add_row(left=0.0, right=bp,
                                  parent=lparent, child=offspring)
                if bp < 1.0:
                    edges.add_row(left=bp, right=1.0,
                                  parent=rparent, child=offspring)
    if self.debug:
        print("Done! Final pop:")
        print(pop)
    # Use a set: testing membership in the list was O(N) for every node.
    final_pop = set(pop)
    flags = [(msprime.NODE_IS_SAMPLE if u in final_pop else 0)
             for u in range(nodes.num_rows)]
    nodes.set_columns(time=nodes.time, flags=flags)
    if self.debug:
        print("Done.")
        print("Nodes:")
        print(nodes)
        print("Edges:")
        print(edges)
    return msprime.TableCollection(nodes, edges, migrations, sites,
                                   mutations, provenances)
def null_tree_sequence():
    """Return a tree sequence built from empty node and edge tables."""
    empty_nodes = msprime.NodeTable()
    empty_edges = msprime.EdgeTable()
    return msprime.load_tables(nodes=empty_nodes, edges=empty_edges)
def __init__(self, gc_interval=None):
    """
    :param gc_interval: Garbage collection interval (stored as given).
    """
    self.gc_interval = gc_interval
    self.last_gc_time = 0
    # Empty node and edge tables owned by this instance.
    self.__nodes = msprime.NodeTable()
    self.__edges = msprime.EdgeTable()
def __init__(self, sample_size, num_loci, recombination_rate,
             migration_matrix, sample_configuration, population_growth_rates,
             population_sizes, population_growth_rate_changes,
             population_size_changes, migration_matrix_element_changes,
             bottlenecks, model='hudson', max_segments=100):
    """
    Set up the initial simulation state.

    :param sample_size: Total number of samples; must equal
        ``sum(sample_configuration)``.
    :param num_loci: Number of loci; each sample starts with a segment
        covering ``[0, num_loci)``.
    :param recombination_rate: Stored as ``self.r``.
    :param migration_matrix: Square N x N matrix with a zero diagonal.
    :param sample_configuration: Number of samples in each population.
    :param population_growth_rates: Initial growth rate per population.
    :param population_sizes: Initial size per population.
    :param population_growth_rate_changes: ``(time, pop_id, new_rate)``
        events.
    :param population_size_changes: ``(time, pop_id, new_size)`` events.
    :param migration_matrix_element_changes: ``(time, i, j, new_rate)``
        events.
    :param bottlenecks: ``(time, pop_id, intensity)`` events.
    :param model: Model name, stored as ``self.model``.
    :param max_segments: Capacity of the preallocated segment pool.
    """
    # Must be a square matrix.
    N = len(migration_matrix)
    assert len(sample_configuration) == N
    assert len(population_growth_rates) == N
    assert len(population_sizes) == N
    for j in range(N):
        assert N == len(migration_matrix[j])
        assert migration_matrix[j][j] == 0
    assert sum(sample_configuration) == sample_size

    self.model = model
    self.n = sample_size
    self.m = num_loci
    self.r = recombination_rate
    self.migration_matrix = migration_matrix
    self.max_segments = max_segments
    self.segment_stack = []
    # Segment pool; index 0 is unused so segment indexes start at 1.
    self.segments = [None for j in range(self.max_segments + 1)]
    for j in range(self.max_segments):
        s = Segment(j + 1)
        self.segments[j + 1] = s
        self.segment_stack.append(s)
    self.P = [Population(id_) for id_ in range(N)]
    self.L = FenwickTree(self.max_segments)
    self.S = bintrees.AVLTree()
    # The output tree sequence.
    self.nodes = msprime.NodeTable()
    self.edges = msprime.EdgeTable()
    self.edge_buffer = []
    for pop_index in range(N):
        pop_sample_size = sample_configuration[pop_index]
        self.P[pop_index].set_start_size(population_sizes[pop_index])
        self.P[pop_index].set_growth_rate(
            population_growth_rates[pop_index], 0)
        for _ in range(pop_sample_size):
            # The new sample's node ID is the next row of the node table.
            j = len(self.nodes)
            x = self.alloc_segment(0, self.m, j, pop_index)
            self.L.set_value(x.index, self.m - 1)
            self.P[pop_index].add(x)
            self.nodes.add_row(
                flags=msprime.NODE_IS_SAMPLE, time=0, population=pop_index)
    self.S[0] = self.n
    self.S[self.m] = -1
    self.t = 0
    self.num_ca_events = 0
    self.num_re_events = 0
    # Sentinel event at the largest float so the queue is never empty.
    self.modifier_events = [(sys.float_info.max, None, None)]
    for time, pop_id, new_size in population_size_changes:
        self.modifier_events.append(
            (time, self.change_population_size, (int(pop_id), new_size)))
    for time, pop_id, new_rate in population_growth_rate_changes:
        self.modifier_events.append(
            (time, self.change_population_growth_rate,
                (int(pop_id), new_rate, time)))
    for time, pop_i, pop_j, new_rate in migration_matrix_element_changes:
        self.modifier_events.append(
            (time, self.change_migration_matrix_element,
                (int(pop_i), int(pop_j), new_rate)))
    for time, pop_id, intensity in bottlenecks:
        self.modifier_events.append(
            (time, self.bottleneck_event, (int(pop_id), intensity)))
    # Sort on time only. A plain tuple sort falls through to comparing bound
    # methods when two events share a time, which raises TypeError; a
    # stable key sort keeps same-time events in insertion order.
    self.modifier_events.sort(key=lambda event: event[0])
def _load_legacy_hdf5_v2(root, remove_duplicate_positions):
    """
    Convert a version 2 legacy HDF5 file into a tree sequence.

    Version 2 stores coalescence records with a flattened ``children`` array
    holding two entries per record. Each record becomes two consecutive
    edges that share the record's interval and parent.

    :param root: The root group of the legacy HDF5 file.
    :param remove_duplicate_positions: Passed to ``_convert_hdf5_mutations``
        when the file contains mutations.
    :return: The converted tree sequence, with sorted tables and provenance
        records for tree generation, mutation generation (if present) and
        the upgrade.
    """
    # Get the coalescence records
    trees_group = root["trees"]
    old_timestamp = datetime.datetime.min.isoformat()
    provenances = msprime.ProvenanceTable()
    provenances.add_row(
        timestamp=old_timestamp,
        record=_get_v2_provenance("generate_trees", trees_group.attrs))

    # Duplicate every per-record value so it lines up with the two children.
    cr_node = np.array(trees_group["node"], dtype=np.int32)
    parent = np.repeat(cr_node, 2)
    left = np.repeat(np.array(trees_group["left"], dtype=np.float64), 2)
    right = np.repeat(np.array(trees_group["right"], dtype=np.float64), 2)
    child = np.array(trees_group["children"], dtype=np.int32).flatten()
    edges = msprime.EdgeTable()
    edges.set_columns(left=left, right=right, parent=parent, child=child)

    num_nodes = max(np.max(child), np.max(cr_node)) + 1
    # Nodes with IDs below the smallest parent ID are the samples.
    sample_size = np.min(cr_node)
    flags = np.zeros(num_nodes, dtype=np.uint32)
    flags[:sample_size] = msprime.NODE_IS_SAMPLE
    population = np.zeros(num_nodes, dtype=np.int32)
    time = np.zeros(num_nodes, dtype=np.float64)
    time[cr_node] = np.array(trees_group["time"])
    population[cr_node] = np.array(trees_group["population"], dtype=np.int32)
    if "samples" in root:
        samples_group = root["samples"]
        population[:sample_size] = samples_group["population"]
        if "time" in samples_group:
            time[:sample_size] = samples_group["time"]
    nodes = msprime.NodeTable()
    nodes.set_columns(flags=flags, population=population, time=time)

    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    if "mutations" in root:
        mutations_group = root["mutations"]
        _convert_hdf5_mutations(mutations_group, sites, mutations,
                                remove_duplicate_positions)
        provenances.add_row(
            timestamp=old_timestamp,
            record=_get_v2_provenance(
                "generate_mutations", mutations_group.attrs))
    provenances.add_row(_get_upgrade_provenance(root))
    msprime.sort_tables(nodes=nodes, edges=edges, sites=sites,
                        mutations=mutations)
    return msprime.load_tables(nodes=nodes, edges=edges, sites=sites,
                               mutations=mutations, provenances=provenances)
def get_tree_sequence(self, rescale_positions=True, all_sites=False):
    """
    Returns the current state of the build tree sequence. All samples and
    ancestors will have the sample node flag set.

    :param rescale_positions: If True, edge coordinates (site indexes) are
        mapped to physical positions from ``self.sample_data``. Otherwise
        site indexes are used directly as positions.
    :param all_sites: If True, also append the singleton sites (one
        mutation each, derived state "1") and the invariant sites (no
        mutations) from ``self.sample_data``.
    :return: The tree sequence loaded from the sorted tables.
    """
    # TODO Change the API here to ask whether we want a final tree sequence
    # or not. In the latter case we also need to translate the ancestral
    # and derived states to the input values.
    tsb = self.tree_sequence_builder
    flags, time = tsb.dump_nodes()
    nodes = msprime.NodeTable()
    nodes.set_columns(flags=flags, time=time)

    left, right, parent, child = tsb.dump_edges()
    if rescale_positions:
        position = self.sample_data.position[:]
        sequence_length = self.sample_data.sequence_length
        if sequence_length is None or sequence_length < position[-1]:
            sequence_length = position[-1] + 1
        # Subset down to the variants.
        position = position[self.sample_data.variant_site[:]]
        # Edge coordinates are indexes into x. Index j maps to the position
        # of variant j, and the index one past the last variant maps to the
        # sequence length. x[0] is forced to 0 so coordinates start at zero.
        x = np.hstack([position, [sequence_length]])
        x[0] = 0
        left = x[left]
        right = x[right]
    else:
        position = np.arange(tsb.num_sites)
        sequence_length = max(1, tsb.num_sites)
    edges = msprime.EdgeTable()
    edges.set_columns(left=left, right=right, parent=parent, child=child)

    # States are single ASCII characters: ancestral "0"; derived "0" plus
    # the builder's derived state value.
    sites = msprime.SiteTable()
    sites.set_columns(
        position=position,
        ancestral_state=np.zeros(tsb.num_sites, dtype=np.int8) + ord('0'),
        ancestral_state_offset=np.arange(tsb.num_sites + 1, dtype=np.uint32))
    mutations = msprime.MutationTable()
    site, node, derived_state, parent = tsb.dump_mutations()
    derived_state += ord('0')
    mutations.set_columns(
        site=site, node=node, derived_state=derived_state,
        derived_state_offset=np.arange(
            tsb.num_mutations + 1, dtype=np.uint32),
        parent=parent)
    if all_sites:
        # Append the sites and mutations for each singleton.
        num_singletons = self.sample_data.num_singleton_sites
        singleton_site = self.sample_data.singleton_site[:]
        singleton_sample = self.sample_data.singleton_sample[:]
        pos = self.sample_data.position[:]
        new_sites = np.arange(
            len(sites), len(sites) + num_singletons, dtype=np.int32)
        sites.append_columns(
            position=pos[singleton_site],
            ancestral_state=np.zeros(num_singletons, dtype=np.int8) + ord('0'),
            ancestral_state_offset=np.arange(
                num_singletons + 1, dtype=np.uint32))
        mutations.append_columns(
            site=new_sites, node=self.sample_ids[singleton_sample],
            derived_state=np.zeros(num_singletons, dtype=np.int8) + ord('1'),
            derived_state_offset=np.arange(
                num_singletons + 1, dtype=np.uint32))
        # Get the invariant sites
        num_invariants = self.sample_data.num_invariant_sites
        invariant_site = self.sample_data.invariant_site[:]
        sites.append_columns(
            position=pos[invariant_site],
            ancestral_state=np.zeros(num_invariants, dtype=np.int8) + ord('0'),
            ancestral_state_offset=np.arange(
                num_invariants + 1, dtype=np.uint32))

    msprime.sort_tables(nodes, edges, sites=sites, mutations=mutations)
    return msprime.load_tables(
        nodes=nodes, edges=edges, sites=sites, mutations=mutations,
        sequence_length=sequence_length)
def run_simplify_num_edges_benchmark(args):
    """
    Time ``msprime.simplify_tables`` against the number of input edges.

    Loads the tree sequence at ``args.file``. For each sample size, it
    simplifies successively longer suffixes of the edge table. Results are
    written to ``data/simplify_num_edges.dat`` and plotted to
    ``simplify_num_edges.png``.
    """
    ts = msprime.load(args.file)
    np.random.seed(1)
    print("num_nodes = ", ts.num_nodes)
    print("num_edges = ", ts.num_edges)
    num_slices = 10
    tables = ts.dump_tables()
    nodes = tables.nodes
    edges = tables.edges
    node_time = nodes.time
    left = edges.left
    right = edges.right
    parent = edges.parent
    child = edges.child
    size = left.nbytes + right.nbytes + parent.nbytes + child.nbytes
    print("Total edge size = ", size / 1024**3, "GiB")
    sample_sizes = [10, 100, 1000]
    num_sample_sizes = len(sample_sizes)
    num_edges = np.zeros(num_slices * num_sample_sizes)
    simplify_time = np.zeros(num_slices * num_sample_sizes)
    sample_size = np.zeros(num_slices * num_sample_sizes)
    slice_size = ts.num_edges // num_slices
    j = 0
    for N in sample_sizes:
        for start in range(ts.num_edges - slice_size, 0, -slice_size):
            max_node = np.max(child[start:])
            samples = np.arange(max_node - N, max_node, dtype=np.int32)
            subset_nodes = msprime.NodeTable()
            subset_nodes.set_columns(
                time=node_time[:max_node + 1],
                flags=np.ones(max_node + 1, dtype=np.uint32))
            subset_edges = msprime.EdgeTable()
            subset_edges.set_columns(
                left=left[start:], right=right[start:],
                parent=parent[start:], child=child[start:])
            before = time.process_time()
            msprime.simplify_tables(
                samples=samples, nodes=subset_nodes, edges=subset_edges)
            duration = time.process_time() - before
            num_edges[j] = ts.num_edges - start
            simplify_time[j] = duration
            sample_size[j] = N
            print(N, num_edges[j], duration, num_edges[j] / duration,
                  "per second")
            j += 1
    # The slice loop can run fewer than num_slices times per sample size
    # (e.g. when num_edges is a multiple of slice_size). Drop the unfilled
    # trailing rows so they are not written or plotted as zeros.
    num_edges = num_edges[:j]
    simplify_time = simplify_time[:j]
    sample_size = sample_size[:j]
    df = pd.DataFrame({
        "sample_size": sample_size,
        "num_edges": num_edges,
        "time": simplify_time
    })
    df.to_csv("data/simplify_num_edges.dat")

    for N in sample_sizes:
        index = sample_size == N
        plt.plot(num_edges[index], simplify_time[index], marker="o")
    plt.xlabel("num edges")
    plt.ylabel("Time to simplify (s)")
    plt.savefig("simplify_num_edges.png")