示例#1
0
 def __init__(self, ts, sample, filter_zero_mutation_sites=True):
     """
     Initialise the simplifier state for the input tree sequence ``ts``.

     :param ts: The input tree sequence.
     :param sample: Sequence of input node IDs to use as samples. Each one is
         recorded as an output sample node (via ``record_node``) and given
         ancestry over ``[0, sequence_length)`` (via ``add_ancestry``).
     :param filter_zero_mutation_sites: Stored on the instance for later use.
     """
     self.ts = ts
     self.n = len(sample)
     self.sequence_length = ts.sequence_length
     self.filter_zero_mutation_sites = filter_zero_mutation_sites
     self.num_mutations = ts.num_mutations
     self.input_sites = list(ts.sites())
     # Per input node, the head/tail of its ancestry chain (None when empty);
     # presumably a linked list of segments maintained by add_ancestry.
     self.A_head = [None] * ts.num_nodes
     self.A_tail = [None] * ts.num_nodes
     # Output tables, each constructed once with the size of the matching
     # input table (presumably as an allocation hint).
     self.node_table = msprime.NodeTable(ts.num_nodes)
     self.edge_table = msprime.EdgeTable(ts.num_edges)
     self.site_table = msprime.SiteTable(ts.num_sites)
     self.mutation_table = msprime.MutationTable(ts.num_mutations)
     self.edge_buffer = {}
     # Presumably input node ID -> output node ID; -1 marks unmapped nodes.
     self.node_id_map = np.full(ts.num_nodes, -1, dtype=np.int32)
     self.mutation_node_map = [-1] * self.num_mutations
     self.samples = set(sample)
     for sample_id in sample:
         output_id = self.record_node(sample_id, is_sample=True)
         self.add_ancestry(sample_id, 0, self.sequence_length, output_id)
     # We keep a map of input nodes to mutations: mutation_map[u] lists the
     # (site position, mutation ID) pairs of mutations on input node u.
     self.mutation_map = [[] for _ in range(ts.num_nodes)]
     tables = ts.tables  # read the table collection once
     position = tables.sites.position
     site = tables.mutations.site
     node = tables.mutations.node
     for mutation_id in range(ts.num_mutations):
         site_position = position[site[mutation_id]]
         self.mutation_map[node[mutation_id]].append(
             (site_position, mutation_id))
示例#2
0
 def __init__(self, ts, sample, filter_zero_mutation_sites=True):
     """
     Initialise the simplifier state for the input tree sequence ``ts``.

     :param ts: The input tree sequence.
     :param sample: Sequence of input node IDs to use as samples; each is
         passed to ``insert_sample``.
     :param filter_zero_mutation_sites: Stored on the instance for later use.
     """
     self.ts = ts
     self.n = len(sample)
     self.sequence_length = ts.sequence_length
     self.filter_zero_mutation_sites = filter_zero_mutation_sites
     self.num_mutations = ts.num_mutations
     self.input_sites = list(ts.sites())
     # A maps input node IDs to the extant ancestor chain. Once the algorithm
     # has processed the ancestors, they are removed from the map.
     self.A = {}
     # Output tables, each constructed once with the size of the matching
     # input table (presumably as an allocation hint).
     self.node_table = msprime.NodeTable(ts.num_nodes)
     self.edge_table = msprime.EdgeTable(ts.num_edges)
     self.site_table = msprime.SiteTable(ts.num_sites)
     self.mutation_table = msprime.MutationTable(ts.num_mutations)
     self.edge_buffer = []
     # Presumably input node ID -> output node ID.
     self.node_id_map = {}
     self.mutation_node_map = [-1] * self.num_mutations
     self.samples = set(sample)
     for sample_id in sample:
         self.insert_sample(sample_id)
     # We keep a map of input nodes to mutations: mutation_map[u] lists the
     # (site position, mutation ID) pairs of mutations on input node u.
     self.mutation_map = [[] for _ in range(ts.num_nodes)]
     tables = ts.tables  # read the table collection once
     position = tables.sites.position
     site = tables.mutations.site
     node = tables.mutations.node
     for mutation_id in range(ts.num_mutations):
         site_position = position[site[mutation_id]]
         self.mutation_map[node[mutation_id]].append(
             (site_position, mutation_id))
示例#3
0
def wright_fisher(N, delta, L, T):
    """
    Direct implementation of Algorithm W.

    Simulates ``T`` generations of ``N`` individuals on a genome of length
    ``L``. In each generation every individual is independently replaced,
    with probability ``delta``, by a new node that inherits ``[0, x)`` from
    one random member of the previous generation and ``[x, L)`` from another,
    with ``x`` uniform on ``[0, L]``. Nodes present in the final generation
    get flags 1 (sample); all others get flags 0.

    :return: The tree sequence built from the sorted node and edge tables.
    """
    edges = msprime.EdgeTable()
    birth_times = [T] * N          # birth time of each node, by node ID
    current = list(range(N))       # node IDs of the current generation
    next_node = N                  # ID the next new node will receive
    for t in range(T - 1, -1, -1):
        offspring = list(current)
        for slot in range(N):
            if random.random() >= delta:
                continue
            offspring[slot] = next_node
            birth_times.append(t)
            left_parent = current[random.randint(0, N - 1)]
            right_parent = current[random.randint(0, N - 1)]
            breakpoint = random.uniform(0, L)
            edges.add_row(0, breakpoint, left_parent, next_node)
            edges.add_row(breakpoint, L, right_parent, next_node)
            next_node += 1
        current = offspring
    nodes = msprime.NodeTable()
    survivors = set(current)
    for node_id, birth_time in enumerate(birth_times):
        nodes.add_row(time=birth_time, flags=int(node_id in survivors))
    msprime.sort_tables(nodes=nodes, edges=edges)
    return msprime.load_tables(nodes=nodes, edges=edges)
示例#4
0
def _load_legacy_hdf5_v3(root, remove_duplicate_positions):
    """
    Convert a version 3 legacy HDF5 file into a tree sequence.

    Version 3 stores coalescence records: a parent node, an interval given as
    indexes into the ``breakpoints`` array, and ``num_children`` child nodes.
    Each record is expanded into one edge per child.

    :param root: Root group of the open HDF5 file.
    :param remove_duplicate_positions: Passed through to
        ``_convert_hdf5_mutations``.
    :return: The upgraded tree sequence, with a provenance row appended for
        the upgrade.
    """
    # get the trees group for the records and samples
    trees_group = root["trees"]
    nodes_group = trees_group["nodes"]
    time = np.array(nodes_group["time"])

    breakpoints = np.array(trees_group["breakpoints"])
    records_group = trees_group["records"]
    left_indexes = np.array(records_group["left"])
    right_indexes = np.array(records_group["right"])
    record_node = np.array(records_group["node"], dtype=np.int32)
    num_nodes = time.shape[0]
    # Nodes with IDs below the smallest parent ID are treated as samples.
    sample_size = np.min(record_node)
    flags = np.zeros(num_nodes, dtype=np.uint32)
    flags[:sample_size] = msprime.NODE_IS_SAMPLE

    # Expand records into edges: each record's interval and parent are
    # repeated once per child, in record order (vectorised with np.repeat).
    children_length = np.array(records_group["num_children"], dtype=np.uint32)
    left = np.repeat(
        breakpoints[left_indexes], children_length).astype(np.float64)
    right = np.repeat(
        breakpoints[right_indexes], children_length).astype(np.float64)
    parent = np.repeat(record_node, children_length).astype(np.int32)

    nodes = msprime.NodeTable()
    nodes.set_columns(flags=flags,
                      time=time,
                      population=nodes_group["population"])
    edges = msprime.EdgeTable()
    edges.set_columns(left=left,
                      right=right,
                      parent=parent,
                      child=records_group["children"])
    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    if "mutations" in root:
        _convert_hdf5_mutations(root["mutations"], sites, mutations,
                                remove_duplicate_positions)
    # Legacy provenance records carry no timestamp; use the minimum datetime.
    old_timestamp = datetime.datetime.min.isoformat()
    provenances = msprime.ProvenanceTable()
    if "provenance" in root:
        for record in root["provenance"]:
            provenances.add_row(timestamp=old_timestamp, record=record)
    provenances.add_row(_get_upgrade_provenance(root))
    msprime.sort_tables(nodes=nodes,
                        edges=edges,
                        sites=sites,
                        mutations=mutations)
    return msprime.load_tables(nodes=nodes,
                               edges=edges,
                               sites=sites,
                               mutations=mutations,
                               provenances=provenances)
示例#5
0
 def store_output(self):
     """
     Build the output tree sequence and, if ``self.output_path`` is set,
     dump it to that path.

     :return: ``self.get_tree_sequence(rescale_positions=False)`` when
         ``self.num_ancestors > 0``; otherwise an empty tree sequence with
         sequence length 1.
     """
     have_ancestors = self.num_ancestors > 0
     if have_ancestors:
         result = self.get_tree_sequence(rescale_positions=False)
     else:
         # Allocate an empty tree sequence.
         result = msprime.load_tables(
             nodes=msprime.NodeTable(),
             edges=msprime.EdgeTable(),
             sequence_length=1)
     destination = self.output_path
     if destination is not None:
         result.dump(destination)
     return result
示例#6
0
def node_metadata_example():
    """
    Return a simulated tree sequence (100 samples, length 10, recombination
    rate 0.1, seed 1) whose node ``u`` carries the packed string ``"n_u"``
    as metadata. Only the flags and time node columns are copied over.
    """
    source = msprime.simulate(
        sample_size=100, recombination_rate=0.1, length=10, random_seed=1)
    nodes, edges = msprime.NodeTable(), msprime.EdgeTable()
    source.dump_tables(nodes=nodes, edges=edges)
    labels = ["n_{}".format(node_id) for node_id in range(source.num_nodes)]
    packed, offset = msprime.pack_strings(labels)
    labelled_nodes = msprime.NodeTable()
    labelled_nodes.set_columns(
        flags=nodes.flags,
        time=nodes.time,
        metadata=packed,
        metadata_offset=offset)
    return msprime.load_tables(nodes=labelled_nodes, edges=edges)
示例#7
0
def general_mutation_example():
    """
    Return a simulated tree sequence (10 samples, length 10, recombination
    rate 1, seed 2) with two hand-written sites, each carrying one mutation
    on node 0:

    * site 0 at position 0: ancestral "A", derived "T", metadata b"{}"
    * site 1 at position 1: ancestral "C", derived "G", metadata b"{'id':1}"
    """
    simulated = msprime.simulate(
        10, recombination_rate=1, length=10, random_seed=2)
    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    simulated.dump_tables(nodes=nodes, edges=edges)
    # (position, ancestral state, site metadata, derived state)
    site_specs = (
        (0, "A", b"{}", "T"),
        (1, "C", b"{'id':1}", "G"),
    )
    sites = msprime.SiteTable()
    for position, ancestral, site_metadata, _ in site_specs:
        sites.add_row(position=position, ancestral_state=ancestral,
                      metadata=site_metadata)
    mutations = msprime.MutationTable()
    for site_id, (_, _, _, derived) in enumerate(site_specs):
        mutations.add_row(site=site_id, node=0, derived_state=derived)
    return msprime.load_tables(
        nodes=nodes, edges=edges, sites=sites, mutations=mutations)
示例#8
0
 def test_nodes(self):
     # Pickled metadata stored on a node must come back byte-identical and
     # unpickle to an equivalent object.
     metadata = ExampleMetadata(one="node1", two="node2")
     pickled = pickle.dumps(metadata)
     node_table = msprime.NodeTable()
     node_table.add_row(time=0.125, metadata=pickled)
     ts = msprime.load_tables(
         nodes=node_table, edges=msprime.EdgeTable(), sequence_length=1)
     first = ts.node(0)
     self.assertEqual(first.time, 0.125)
     self.assertEqual(first.metadata, pickled)
     restored = pickle.loads(first.metadata)
     for attr in ("one", "two"):
         self.assertEqual(getattr(restored, attr), getattr(metadata, attr))
示例#9
0
 def test_nodes(self):
     # JSON metadata built from classes generated from self.schema must
     # survive a round trip through a node unchanged.
     ns = pjs.ObjectBuilder(json.loads(self.schema)).build_classes()
     metadata = ns.ExampleMetadata(one="node1", two="node2")
     encoded = json.dumps(metadata.as_dict()).encode()
     node_table = msprime.NodeTable()
     node_table.add_row(time=0.125, metadata=encoded)
     ts = msprime.load_tables(
         nodes=node_table, edges=msprime.EdgeTable(), sequence_length=1)
     first = ts.node(0)
     self.assertEqual(first.time, 0.125)
     self.assertEqual(first.metadata, encoded)
     decoded = ns.ExampleMetadata.from_json(first.metadata.decode())
     self.assertEqual(decoded.one, metadata.one)
     self.assertEqual(decoded.two, metadata.two)
示例#10
0
 def __init__(self, gc_interval, trees=None):
     """
     :param gc_interval: Garbage collection interval
     :param trees: An instance of :class:`msprime.TreeSequence`. When given,
         its nodes and edges are dumped into this object's private tables
         and the private ``__process`` flag is set to False; otherwise the
         tables start empty and ``__process`` is True.
     """
     self.gc_interval = gc_interval
     self.last_gc_time = 0.0
     self.__nodes = msprime.NodeTable()
     self.__edges = msprime.EdgeTable()
     self.__process = trees is None
     if not self.__process:
         trees.dump_tables(nodes=self.__nodes, edges=self.__edges)
     # Per-phase time counters, all starting at zero.
     self.__time_sorting = 0.0
     self.__time_appending = 0.0
     self.__time_simplifying = 0.0
     self.__time_prepping = 0.0
示例#11
0
 def test_sites(self):
     # Pickled metadata stored on a site must come back byte-identical and
     # unpickle to an equivalent object.
     metadata = ExampleMetadata(one="node1", two="node2")
     pickled = pickle.dumps(metadata)
     site_table = msprime.SiteTable()
     site_table.add_row(position=0.1, ancestral_state="A", metadata=pickled)
     ts = msprime.load_tables(
         nodes=msprime.NodeTable(),
         edges=msprime.EdgeTable(),
         sites=site_table,
         mutations=msprime.MutationTable(),
         sequence_length=1)
     first = ts.site(0)
     self.assertEqual(first.position, 0.1)
     self.assertEqual(first.ancestral_state, "A")
     self.assertEqual(first.metadata, pickled)
     restored = pickle.loads(first.metadata)
     for attr in ("one", "two"):
         self.assertEqual(getattr(restored, attr), getattr(metadata, attr))
示例#12
0
def permute_nodes(ts, node_map):
    """
    Returns a copy of the specified tree sequence such that the nodes are
    permuted according to the specified map.

    :param ts: The input tree sequence.
    :param node_map: Sequence giving, for each input node ID ``u``, its new
        ID ``node_map[u]``; expected to be a permutation of
        ``range(ts.num_nodes)``.
    :return: A new tree sequence in which nodes are renumbered, edges and
        mutations refer to the new node IDs, node/site/mutation metadata is
        carried over, and a "permute_nodes" provenance record is appended.
    """
    # Mapping from nodes in the new tree sequence back to nodes in the original
    reverse_map = [0] * len(node_map)
    for old_id in range(ts.num_nodes):
        reverse_map[node_map[old_id]] = old_id
    old_nodes = list(ts.nodes())
    new_nodes = msprime.NodeTable()
    for new_id in range(ts.num_nodes):
        old_node = old_nodes[reverse_map[new_id]]
        new_nodes.add_row(flags=old_node.flags,
                          metadata=old_node.metadata,
                          population=old_node.population,
                          time=old_node.time)
    new_edges = msprime.EdgeTable()
    for edge in ts.edges():
        new_edges.add_row(left=edge.left,
                          right=edge.right,
                          parent=node_map[edge.parent],
                          child=node_map[edge.child])
    new_sites = msprime.SiteTable()
    new_mutations = msprime.MutationTable()
    for site in ts.sites():
        # Carry site and mutation metadata over, as is done for nodes, so
        # the result is a faithful copy apart from the node renumbering.
        new_sites.add_row(position=site.position,
                          ancestral_state=site.ancestral_state,
                          metadata=site.metadata)
        for mutation in site.mutations:
            new_mutations.add_row(site=site.id,
                                  derived_state=mutation.derived_state,
                                  node=node_map[mutation.node],
                                  metadata=mutation.metadata)
    msprime.sort_tables(nodes=new_nodes,
                        edges=new_edges,
                        sites=new_sites,
                        mutations=new_mutations)
    provenances = ts.dump_tables().provenances
    add_provenance(provenances, "permute_nodes")
    return msprime.load_tables(nodes=new_nodes,
                               edges=new_edges,
                               sites=new_sites,
                               mutations=new_mutations,
                               provenances=provenances)
示例#13
0
 def test_mutations(self):
     # Pickled metadata stored on a mutation must come back byte-identical
     # and unpickle to an equivalent object.
     metadata = ExampleMetadata(one="node1", two="node2")
     pickled = pickle.dumps(metadata)
     node_table = msprime.NodeTable()
     site_table = msprime.SiteTable()
     mutation_table = msprime.MutationTable()
     node_table.add_row(time=0)
     site_table.add_row(position=0.1, ancestral_state="A")
     mutation_table.add_row(
         site=0, node=0, derived_state="T", metadata=pickled)
     ts = msprime.load_tables(
         nodes=node_table,
         edges=msprime.EdgeTable(),
         sites=site_table,
         mutations=mutation_table,
         sequence_length=1)
     mutation = ts.site(0).mutations[0]
     expected = [(mutation.site, 0), (mutation.node, 0),
                 (mutation.derived_state, "T"), (mutation.metadata, pickled)]
     for actual, wanted in expected:
         self.assertEqual(actual, wanted)
     restored = pickle.loads(mutation.metadata)
     self.assertEqual(restored.one, metadata.one)
     self.assertEqual(restored.two, metadata.two)
示例#14
0
def wright_fisher(N, T, simplify_interval=1):
    """
    An implementation of algorithm W in which the tables are sorted and
    simplified every ``simplify_interval`` generations (i.e. whenever
    ``t % simplify_interval == 0``). The goal here is to measure the number
    of edges in the tree sequence representing the history as a function of
    time.

    For simplicity we assume that the genome length L = 1 and the probability
    of death delta = 1.

    :return: ``(tree_sequence, S)`` where ``S[k]`` is the number of edges
        recorded at the end of generation ``k`` (counting from 0).
    """
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    for _ in range(N):
        nodes.add_row(time=T, flags=1)
    parents = list(range(N))
    S = np.zeros(T, dtype=int)
    for t in reversed(range(T)):
        children = list(parents)
        for slot in range(N):
            child = len(nodes)
            nodes.add_row(time=t, flags=1)
            children[slot] = child
            left_parent = parents[random.randint(0, N - 1)]
            right_parent = parents[random.randint(0, N - 1)]
            breakpoint = random.uniform(0, L)
            edges.add_row(0, breakpoint, left_parent, child)
            edges.add_row(breakpoint, L, right_parent, child)
        parents = children
        if t % simplify_interval == 0:
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(children, nodes, edges)
            # The code assumes simplification renumbers the samples to 0..N-1.
            parents = list(range(N))
        S[T - t - 1] = len(edges)
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges), S
示例#15
0
 def get_empty_tree(self):
     """Return the first tree of a tree sequence built from empty node and
     edge tables with sequence length 1."""
     empty_ts = msprime.load_tables(
         nodes=msprime.NodeTable(),
         edges=msprime.EdgeTable(),
         sequence_length=1)
     tree_iterator = iter(empty_ts.trees())
     return next(tree_iterator)
示例#16
0
def _load_legacy_hdf5_v10(root, remove_duplicate_positions=False):
    """
    Convert a version 10 legacy HDF5 file into a tree sequence.

    Node, edge, migration, site, mutation and provenance tables are read
    from the groups of the same names. Metadata columns and the provenance
    ``record`` column are optional. Migration, site, mutation and provenance
    tables whose key column is absent are left empty. A provenance row
    recording the upgrade is appended.

    :param root: Root group of the open HDF5 file.
    :param remove_duplicate_positions: Ignored; v10 files cannot contain
        duplicate site positions.
    :return: The upgraded tree sequence.
    """

    def optional_metadata(group):
        # (metadata, metadata_offset) columns of group, or (None, None) when
        # the group stores no metadata.
        if "metadata" in group:
            return group["metadata"], group["metadata_offset"]
        return None, None

    nodes_group = root["nodes"]
    nodes = msprime.NodeTable()
    metadata, metadata_offset = optional_metadata(nodes_group)
    nodes.set_columns(flags=nodes_group["flags"],
                      population=nodes_group["population"],
                      time=nodes_group["time"],
                      metadata=metadata,
                      metadata_offset=metadata_offset)

    edges_group = root["edges"]
    edges = msprime.EdgeTable()
    edges.set_columns(left=edges_group["left"],
                      right=edges_group["right"],
                      parent=edges_group["parent"],
                      child=edges_group["child"])

    migrations_group = root["migrations"]
    migrations = msprime.MigrationTable()
    if "left" in migrations_group:
        migrations.set_columns(left=migrations_group["left"],
                               right=migrations_group["right"],
                               node=migrations_group["node"],
                               source=migrations_group["source"],
                               dest=migrations_group["dest"],
                               time=migrations_group["time"])

    sites_group = root["sites"]
    sites = msprime.SiteTable()
    if "position" in sites_group:
        metadata, metadata_offset = optional_metadata(sites_group)
        sites.set_columns(
            position=sites_group["position"],
            ancestral_state=sites_group["ancestral_state"],
            ancestral_state_offset=sites_group["ancestral_state_offset"],
            metadata=metadata,
            metadata_offset=metadata_offset)

    mutations_group = root["mutations"]
    mutations = msprime.MutationTable()
    if "site" in mutations_group:
        metadata, metadata_offset = optional_metadata(mutations_group)
        mutations.set_columns(
            site=mutations_group["site"],
            node=mutations_group["node"],
            parent=mutations_group["parent"],
            derived_state=mutations_group["derived_state"],
            derived_state_offset=mutations_group["derived_state_offset"],
            metadata=metadata,
            metadata_offset=metadata_offset)

    provenances_group = root["provenances"]
    provenances = msprime.ProvenanceTable()
    if "timestamp" in provenances_group:
        timestamp = provenances_group["timestamp"]
        timestamp_offset = provenances_group["timestamp_offset"]
        if "record" in provenances_group:
            record = provenances_group["record"]
            record_offset = provenances_group["record_offset"]
        else:
            # No stored records: every row gets an empty record. With
            # all-zero offsets the record column must be zero-length; the
            # former np.empty_like(timestamp) supplied len(timestamp)
            # uninitialised bytes instead.
            record = np.zeros(0, dtype=timestamp.dtype)
            record_offset = np.zeros_like(timestamp_offset)
        provenances.set_columns(timestamp=timestamp,
                                timestamp_offset=timestamp_offset,
                                record=record,
                                record_offset=record_offset)
    provenances.add_row(_get_upgrade_provenance(root))

    return msprime.load_tables(nodes=nodes,
                               edges=edges,
                               migrations=migrations,
                               sites=sites,
                               mutations=mutations,
                               provenances=provenances)
        expensive_check(popsize, edges, nodes)

    max_gen = nodes['generation'].max()
    assert (int(max_gen) == 20 * popsize)

    # Convert node times from forwards to backwards
    nodes['generation'] = nodes['generation'] - max_gen
    nodes['generation'] = nodes['generation'] * -1.0

    # Construct and populate msprime's tables
    flags = np.empty([len(nodes)], dtype=np.uint32)
    flags.fill(1)

    prior_ts = msprime.simulate(2 * popsize)
    nt = msprime.NodeTable()
    es = msprime.EdgeTable()
    prior_ts.dump_tables(nodes=nt, edges=es)
    nt.set_columns(
        flags=nt.flags,  #[2 * popsize:],
        population=nt.population,  #[2 * popsize:],
        time=nt.time + ngens + 1)
    node_offset = nt.num_rows

    nt.append_columns(flags=flags,
                      population=nodes['population'] + node_offset,
                      time=nodes['generation'])

    es.append_columns(left=edges['left'],
                      right=edges['right'],
                      parent=edges['parent'] + node_offset,
                      child=edges['child'] + node_offset)
示例#18
0
def wfrec(nsam, rho, nsites, theta):
    """
    Simulate a genealogy with coalescence and recombination for ``nsam``
    samples over the region ``[0, nsites)`` and return it as a tree sequence.

    The ancestral material of each lineage is an ``IntervalTree``. Each step
    draws an exponential waiting time with mean ``4 / (rcoal + rrec)``. The
    step is a coalescence with probability ``rcoal / (rcoal + rrec)``, where
    ``rcoal = n(n - 1)`` and ``rrec = rho / (nsites - 1) * nlinks``.
    Otherwise it is a recombination on a lineage chosen with probability
    proportional to its link count.

    :param nsam: Number of samples (nodes 0..nsam-1, time 0).
    :param rho: Recombination rate; divided by ``nsites - 1`` to obtain a
        per-link rate.
    :param nsites: Length of the region.
    :param theta: Not used by this function.
    :return: The sorted tree sequence.
    """
    samples = [it.IntervalTree([it.Interval(0, nsites)]) for _ in range(nsam)]

    # dtype=int: np.int was only an alias of int and is removed in NumPy 1.24.
    links = np.array([sumIntervalTree(i) for i in samples], dtype=int)
    nlinks = links.sum()

    n = nsam
    rbp = rho / float(nsites - 1)
    t = 0.0

    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()

    nodes.set_columns(time=np.zeros(nsam),
                      flags=np.ones(nsam, dtype=np.uint32))

    # sample_indexes[k] is the node ID of the lineage stored in samples[k].
    sample_indexes = list(range(len(samples)))
    next_index = len(sample_indexes)

    while n > 1:
        rcoal = float(n * (n - 1))
        rrec = rbp * float(nlinks)

        iscoal = bool(np.random.random_sample(1)[0] < rcoal / (rcoal + rrec))
        t += np.random.exponential(4. / (rcoal + rrec), 1)[0]
        assert len(samples) == len(links), "sample/link error"
        if iscoal:
            chroms = np.sort(np.random.choice(n, 2, replace=False))
            c1 = chroms[0]
            c2 = chroms[1]

            # NOTE(review): ancestral nodes are flagged NODE_IS_SAMPLE;
            # confirm this is intended.
            nodes.add_row(time=t, flags=msprime.NODE_IS_SAMPLE)
            # Each child is linked to the new parent over its *own*
            # ancestral intervals (previously c2 reused c1's intervals).
            for c in (c1, c2):
                for i in samples[c]:
                    edges.add_row(left=i[0],
                                  right=i[1],
                                  parent=next_index,
                                  child=sample_indexes[c])
            newchrom = it.IntervalTree()
            # Merge intervals of the two chromosomes
            # and remove overlaps
            for i in samples[c1]:
                newchrom.append(i)
            for i in samples[c2]:
                newchrom.append(i)
            newchrom.merge_overlaps()
            # Pop the larger index first so the smaller one stays valid.
            samples.pop(c2)
            samples.pop(c1)
            samples.append(newchrom)
            sample_indexes.pop(c2)
            sample_indexes.pop(c1)
            sample_indexes.append(next_index)
            next_index += 1
            n -= 1
        else:
            # Pick a chrom proportional to
            # its total size:
            chrom = np.random.choice(len(sample_indexes),
                                     1,
                                     p=links / links.sum())[0]
            mnpos = min(
                [i for j in samples[chrom] for i in j if i is not None])
            mxpos = max(
                [i for j in samples[chrom] for i in j if i is not None])
            # Breakpoint strictly inside the span so both pieces are
            # non-empty: pos == mnpos would move all material to the right
            # piece and trip the "empty IntervalTree" assertion below.
            pos = np.random.randint(mnpos + 1, mxpos)
            # Split any interval spanning pos, then move everything at or
            # right of pos into a new lineage.
            samples[chrom].chop(pos, pos)
            tc = it.IntervalTree([i for i in samples[chrom] if i[0] >= pos])
            samples[chrom].remove_overlap(pos, nsites)
            samples.append(tc)
            sample_indexes.append(next_index)
            next_index += 1
            n += 1

        assert all([len(i) > 0 for i in samples]), "empty IntervalTree"
        assert len(samples) == len(sample_indexes), "sample/sample_index error"
        links = np.array([sumIntervalTree(i) for i in samples], dtype=int)
        nlinks = links.sum()
        assert len(samples) == len(links), "sample/link error 2"
    for i in range(len(edges)):
        assert edges[i].parent < len(nodes), "parent error"
        assert edges[i].child < len(nodes), "child error"
    msprime.sort_tables(nodes=nodes, edges=edges)
    return msprime.load_tables(nodes=nodes, edges=edges)
示例#19
0
    def run(self, ngens):
        """
        Run the forward simulation for ``ngens`` generations.

        The initial ``self.N`` individuals are nodes at time ``ngens``. When
        ``self.deep_history`` is set, their history comes from an msprime
        coalescent simulation shifted back by ``ngens``.

        In each generation ``t = ngens - 1, ..., 0``, each individual is
        replaced with probability ``1 - self.survival`` by a new node at
        time ``t``. The new node inherits ``[0, bp)`` from ``lparent`` and
        ``[bp, 1)`` from ``rparent``, both drawn from the previous
        generation. Nodes alive at the end are flagged
        ``msprime.NODE_IS_SAMPLE``.

        :return: A ``msprime.TableCollection`` holding the nodes and edges;
            the migration, site, mutation and provenance tables are empty.
        """
        nodes = msprime.NodeTable()
        edges = msprime.EdgeTable()
        migrations = msprime.MigrationTable()
        sites = msprime.SiteTable()
        mutations = msprime.MutationTable()
        provenances = msprime.ProvenanceTable()
        if self.deep_history:
            # initial population
            init_ts = msprime.simulate(self.N, recombination_rate=1.0)
            init_ts.dump_tables(nodes=nodes, edges=edges)
            nodes.set_columns(time=nodes.time + ngens, flags=nodes.flags)
        else:
            for _ in range(self.N):
                nodes.add_row(time=ngens)

        pop = list(range(self.N))
        for t in range(ngens - 1, -1, -1):
            if self.debug:
                print("t:", t)
                print("pop:", pop)

            dead = [random.random() > self.survival for _ in pop]
            # sample these first so that all parents are from the previous gen
            new_parents = [(random.choice(pop), random.choice(pop))
                           for _ in range(sum(dead))]
            k = 0
            if self.debug:
                print("Replacing", sum(dead), "individuals.")
            for j in range(self.N):
                if dead[j]:
                    # this is: offspring ID, lparent, rparent, breakpoint
                    offspring = nodes.num_rows
                    nodes.add_row(time=t)
                    lparent, rparent = new_parents[k]
                    k += 1
                    bp = self.random_breakpoint()
                    if self.debug:
                        print("--->", offspring, lparent, rparent, bp)
                    pop[j] = offspring
                    if bp > 0.0:
                        edges.add_row(left=0.0,
                                      right=bp,
                                      parent=lparent,
                                      child=offspring)
                    if bp < 1.0:
                        edges.add_row(left=bp,
                                      right=1.0,
                                      parent=rparent,
                                      child=offspring)

        if self.debug:
            print("Done! Final pop:")
            print(pop)
        # Set lookup: this runs once per node, and a list lookup would cost
        # O(N) each time.
        final_pop = set(pop)
        flags = [(msprime.NODE_IS_SAMPLE if u in final_pop else 0)
                 for u in range(nodes.num_rows)]
        nodes.set_columns(time=nodes.time, flags=flags)
        if self.debug:
            print("Done.")
            print("Nodes:")
            print(nodes)
            print("Edges:")
            print(edges)
        return msprime.TableCollection(nodes, edges, migrations, sites,
                                       mutations, provenances)
示例#20
0
def null_tree_sequence():
    """Return the tree sequence loaded from an empty node table and an
    empty edge table (no sequence length is passed)."""
    empty_nodes = msprime.NodeTable()
    empty_edges = msprime.EdgeTable()
    return msprime.load_tables(nodes=empty_nodes, edges=empty_edges)
示例#21
0
 def __init__(self, gc_interval=None):
     # Garbage-collection settings: the interval and the time of the last run.
     self.gc_interval = gc_interval
     self.last_gc_time = 0
     # Private node and edge tables, initially empty.
     self.__nodes = msprime.NodeTable()
     self.__edges = msprime.EdgeTable()
示例#22
0
    def __init__(self,
                 sample_size,
                 num_loci,
                 recombination_rate,
                 migration_matrix,
                 sample_configuration,
                 population_growth_rates,
                 population_sizes,
                 population_growth_rate_changes,
                 population_size_changes,
                 migration_matrix_element_changes,
                 bottlenecks,
                 model='hudson',
                 max_segments=100):
        """
        Set up the simulator state.

        :param sample_size: Total number of samples; must equal
            ``sum(sample_configuration)``.
        :param num_loci: Number of loci ``m``. Each sample starts with one
            segment covering ``[0, m)``.
        :param recombination_rate: Stored as ``self.r``.
        :param migration_matrix: Square ``N x N`` matrix with a zero
            diagonal, where ``N`` is the number of populations.
        :param sample_configuration: Number of samples per population.
        :param population_growth_rates: Initial growth rate per population.
        :param population_sizes: Initial size per population.
        :param population_growth_rate_changes: ``(time, pop_id, new_rate)``
            tuples.
        :param population_size_changes: ``(time, pop_id, new_size)`` tuples.
        :param migration_matrix_element_changes:
            ``(time, pop_i, pop_j, new_rate)`` tuples.
        :param bottlenecks: ``(time, pop_id, intensity)`` tuples.
        :param model: Model name, stored as ``self.model``.
        :param max_segments: Number of preallocated segments.
        """
        # Must be a square matrix.
        N = len(migration_matrix)
        assert len(sample_configuration) == N
        assert len(population_growth_rates) == N
        assert len(population_sizes) == N
        for j in range(N):
            assert N == len(migration_matrix[j])
            assert migration_matrix[j][j] == 0
        assert sum(sample_configuration) == sample_size

        self.model = model
        self.n = sample_size
        self.m = num_loci
        self.r = recombination_rate
        self.migration_matrix = migration_matrix
        self.max_segments = max_segments
        # Segment pool: slot 0 is unused; segments are indexed 1..max_segments.
        self.segment_stack = []
        self.segments = [None] * (self.max_segments + 1)
        for j in range(self.max_segments):
            s = Segment(j + 1)
            self.segments[j + 1] = s
            self.segment_stack.append(s)
        self.P = [Population(id_) for id_ in range(N)]
        self.L = FenwickTree(self.max_segments)
        self.S = bintrees.AVLTree()
        # The output tree sequence.
        self.nodes = msprime.NodeTable()
        self.edges = msprime.EdgeTable()
        self.edge_buffer = []
        for pop_index in range(N):
            # Separate name so the sample_size parameter is not shadowed.
            pop_sample_size = sample_configuration[pop_index]
            population = self.P[pop_index]
            population.set_start_size(population_sizes[pop_index])
            population.set_growth_rate(population_growth_rates[pop_index], 0)
            for _ in range(pop_sample_size):
                # The sample's node ID is the next row of the node table.
                node_id = len(self.nodes)
                x = self.alloc_segment(0, self.m, node_id, pop_index)
                self.L.set_value(x.index, self.m - 1)
                population.add(x)
                self.nodes.add_row(flags=msprime.NODE_IS_SAMPLE,
                                   time=0,
                                   population=pop_index)
        self.S[0] = self.n
        self.S[self.m] = -1
        self.t = 0
        self.num_ca_events = 0
        self.num_re_events = 0
        # (time, handler, args) tuples; the sentinel at float max never fires.
        self.modifier_events = [(sys.float_info.max, None, None)]
        for time, pop_id, new_size in population_size_changes:
            self.modifier_events.append(
                (time, self.change_population_size, (int(pop_id), new_size)))
        for time, pop_id, new_rate in population_growth_rate_changes:
            self.modifier_events.append(
                (time, self.change_population_growth_rate, (int(pop_id),
                                                            new_rate, time)))
        for time, pop_i, pop_j, new_rate in migration_matrix_element_changes:
            self.modifier_events.append(
                (time, self.change_migration_matrix_element,
                 (int(pop_i), int(pop_j), new_rate)))
        for time, pop_id, intensity in bottlenecks:
            self.modifier_events.append(
                (time, self.bottleneck_event, (int(pop_id), intensity)))
        # Sort by time only: a plain tuple sort would compare the bound-method
        # handlers when two events share a time, which raises TypeError.
        # The sort is stable, so simultaneous events keep their input order.
        self.modifier_events.sort(key=lambda event: event[0])
示例#23
0
def _load_legacy_hdf5_v2(root, remove_duplicate_positions):
    """
    Upgrades a version 2 legacy HDF5 tree sequence into a tree sequence.

    Each legacy coalescence record stores an interval (``left``, ``right``),
    a parent ``node`` and a pair of ``children``; it is expanded into two
    edges that share the record's interval and parent. Node flags, times and
    populations are rebuilt from the records (plus the optional ``samples``
    group), mutations are converted when a ``mutations`` group is present,
    and provenance rows for the original generation steps and for this
    upgrade are attached.

    :param root: The opened legacy HDF5 file root (mapping-like group).
    :param remove_duplicate_positions: Passed through unchanged to
        ``_convert_hdf5_mutations``.
    :return: The upgraded tree sequence from ``msprime.load_tables``.
    """
    trees = root["trees"]
    # Legacy files carry no real timestamps, so the earliest representable
    # datetime is used for the historical provenance rows.
    legacy_timestamp = datetime.datetime.min.isoformat()
    provenance_table = msprime.ProvenanceTable()
    provenance_table.add_row(
        timestamp=legacy_timestamp,
        record=_get_v2_provenance("generate_trees", trees.attrs))

    # Every record has two children, so each per-record column is repeated
    # element-wise (r0, r0, r1, r1, ...) to line up with the flattened
    # children array (c0a, c0b, c1a, c1b, ...).
    record_node = np.array(trees["node"], dtype=np.int32)
    edge_parent = np.repeat(record_node, 2)
    edge_left = np.repeat(np.array(trees["left"], dtype=np.float64), 2)
    edge_right = np.repeat(np.array(trees["right"], dtype=np.float64), 2)
    edge_child = np.array(trees["children"], dtype=np.int32).flatten()
    edge_table = msprime.EdgeTable()
    edge_table.set_columns(
        left=edge_left, right=edge_right, parent=edge_parent, child=edge_child)

    node_count = max(np.max(edge_child), np.max(record_node)) + 1
    # Nodes below the smallest record parent ID are treated as samples.
    num_samples = np.min(record_node)
    node_flags = np.zeros(node_count, dtype=np.uint32)
    node_flags[:num_samples] = msprime.NODE_IS_SAMPLE
    node_time = np.zeros(node_count, dtype=np.float64)
    node_time[record_node] = np.array(trees["time"])
    node_population = np.zeros(node_count, dtype=np.int32)
    node_population[record_node] = np.array(
        trees["population"], dtype=np.int32)
    if "samples" in root:
        sample_info = root["samples"]
        node_population[:num_samples] = sample_info["population"]
        if "time" in sample_info:
            node_time[:num_samples] = sample_info["time"]
    node_table = msprime.NodeTable()
    node_table.set_columns(
        flags=node_flags, population=node_population, time=node_time)

    site_table = msprime.SiteTable()
    mutation_table = msprime.MutationTable()
    if "mutations" in root:
        mutation_info = root["mutations"]
        _convert_hdf5_mutations(
            mutation_info, site_table, mutation_table,
            remove_duplicate_positions)
        provenance_table.add_row(
            timestamp=legacy_timestamp,
            record=_get_v2_provenance("generate_mutations",
                                      mutation_info.attrs))
    provenance_table.add_row(_get_upgrade_provenance(root))

    table_kwargs = {
        "nodes": node_table,
        "edges": edge_table,
        "sites": site_table,
        "mutations": mutation_table,
    }
    msprime.sort_tables(**table_kwargs)
    return msprime.load_tables(provenances=provenance_table, **table_kwargs)
示例#24
0
    def get_tree_sequence(self, rescale_positions=True, all_sites=False):
        """
        Returns the current state of the build tree sequence. All samples and
        ancestors will have the sample node flag set.

        :param bool rescale_positions: If True, edge coordinates from the
            tree sequence builder (indexes into the variant sites) are mapped
            onto the physical site positions from ``self.sample_data``.
            Otherwise site indexes are used directly as coordinates and the
            sequence length is ``max(1, num_sites)``.
        :param bool all_sites: If True, also append the singleton sites (each
            with a single mutation to '1' on the corresponding sample node)
            and the invariant sites (no mutations) from ``self.sample_data``.
            NOTE(review): these appended sites always use the physical
            positions from ``sample_data``, even when ``rescale_positions`` is
            False — confirm that combination is intended.
        :return: The tree sequence loaded from the sorted tables.
        """
        # TODO Change the API here to ask whether we want a final tree sequence
        # or not. In the latter case we also need to translate the ancestral
        # and derived states to the input values.
        tsb = self.tree_sequence_builder
        node_flags, node_time = tsb.dump_nodes()
        nodes = msprime.NodeTable()
        nodes.set_columns(flags=node_flags, time=node_time)

        left, right, parent, child = tsb.dump_edges()
        if rescale_positions:
            position = self.sample_data.position[:]
            sequence_length = self.sample_data.sequence_length
            # Site positions must be strictly less than the sequence length,
            # so extend it when it is missing or does not exceed the last
            # site (using '<' here allowed a site exactly at the end).
            if sequence_length is None or sequence_length <= position[-1]:
                sequence_length = position[-1] + 1
            # Subset down to the variants.
            position = position[self.sample_data.variant_site[:]]
            # Edge coordinates are indexes into this lookup: index k maps to
            # the k-th variant position and the final index maps to the
            # sequence end. The first entry is pinned to 0 so the leftmost
            # edges start at the sequence origin.
            x = np.hstack([position, [sequence_length]])
            x[0] = 0
            left = x[left]
            right = x[right]
        else:
            position = np.arange(tsb.num_sites)
            sequence_length = max(1, tsb.num_sites)

        edges = msprime.EdgeTable()
        edges.set_columns(left=left, right=right, parent=parent, child=child)

        # Every site gets the single-character ancestral state '0'.
        sites = msprime.SiteTable()
        sites.set_columns(
            position=position,
            ancestral_state=np.zeros(tsb.num_sites, dtype=np.int8) + ord('0'),
            ancestral_state_offset=np.arange(tsb.num_sites + 1,
                                             dtype=np.uint32))

        # Distinct names keep the mutation columns from shadowing the edge
        # columns above (previously 'parent' was reused for both).
        (mutation_site, mutation_node, derived_state,
         mutation_parent) = tsb.dump_mutations()
        # Shift the builder's integer allele codes to ASCII digit characters,
        # one character per mutation.
        derived_state += ord('0')
        mutations = msprime.MutationTable()
        mutations.set_columns(site=mutation_site,
                              node=mutation_node,
                              derived_state=derived_state,
                              derived_state_offset=np.arange(
                                  tsb.num_mutations + 1, dtype=np.uint32),
                              parent=mutation_parent)
        if all_sites:
            # Append the sites and mutations for each singleton.
            num_singletons = self.sample_data.num_singleton_sites
            singleton_site = self.sample_data.singleton_site[:]
            singleton_sample = self.sample_data.singleton_sample[:]
            pos = self.sample_data.position[:]
            new_sites = np.arange(len(sites),
                                  len(sites) + num_singletons,
                                  dtype=np.int32)
            sites.append_columns(
                position=pos[singleton_site],
                ancestral_state=np.zeros(num_singletons, dtype=np.int8) +
                ord('0'),
                ancestral_state_offset=np.arange(num_singletons + 1,
                                                 dtype=np.uint32))
            mutations.append_columns(
                site=new_sites,
                node=self.sample_ids[singleton_sample],
                derived_state=np.zeros(num_singletons, dtype=np.int8) +
                ord('1'),
                derived_state_offset=np.arange(num_singletons + 1,
                                               dtype=np.uint32))
            # Get the invariant sites (sites only, no mutations).
            num_invariants = self.sample_data.num_invariant_sites
            invariant_site = self.sample_data.invariant_site[:]
            sites.append_columns(
                position=pos[invariant_site],
                ancestral_state=np.zeros(num_invariants, dtype=np.int8) +
                ord('0'),
                ancestral_state_offset=np.arange(num_invariants + 1,
                                                 dtype=np.uint32))

        # Appended sites are out of position order, so sort before loading.
        msprime.sort_tables(nodes, edges, sites=sites, mutations=mutations)
        return msprime.load_tables(nodes=nodes,
                                   edges=edges,
                                   sites=sites,
                                   mutations=mutations,
                                   sequence_length=sequence_length)
示例#25
0
def run_simplify_num_edges_benchmark(args):
    """
    Benchmarks ``msprime.simplify_tables`` against the number of input edges.

    Loads the tree sequence in ``args.file``. For each sample size in
    ``sample_sizes``, it simplifies successively larger suffixes of the edge
    table, growing by ``slice_size`` edges each step. It samples the ``N``
    node IDs immediately below the largest child ID in the suffix. Each
    timing is printed. All timings are written to
    ``data/simplify_num_edges.dat``, and a plot of simplify time against
    number of edges (one line per sample size) is saved to
    ``simplify_num_edges.png``.

    :param args: Parsed command-line arguments; only ``args.file`` is read.
    """
    ts = msprime.load(args.file)
    np.random.seed(1)
    print("num_nodes = ", ts.num_nodes)
    print("num_edges = ", ts.num_edges)
    num_slices = 10

    tables = ts.dump_tables()
    node_time = tables.nodes.time
    edges = tables.edges
    left = edges.left
    right = edges.right
    parent = edges.parent
    child = edges.child

    size = left.nbytes + right.nbytes + parent.nbytes + child.nbytes
    print("Total edge size = ", size / 1024**3, "GiB")
    sample_sizes = [10, 100, 1000]
    # A zero step (fewer edges than slices) would make range() raise.
    slice_size = max(1, ts.num_edges // num_slices)

    # Results are collected in lists because the range() below does not visit
    # exactly num_slices starts. It visits one fewer when num_edges is a
    # multiple of num_slices, and can visit many more for small inputs. Fixed
    # preallocated arrays therefore either kept zero-filled junk rows or
    # overflowed.
    rows_num_edges = []
    rows_time = []
    rows_sample_size = []
    for N in sample_sizes:
        for start in range(ts.num_edges - slice_size, 0, -slice_size):
            max_node = np.max(child[start:])
            samples = np.arange(max_node - N, max_node, dtype=np.int32)
            subset_nodes = msprime.NodeTable()
            subset_nodes.set_columns(time=node_time[:max_node + 1],
                                     flags=np.ones(max_node + 1,
                                                   dtype=np.uint32))
            subset_edges = msprime.EdgeTable()
            subset_edges.set_columns(left=left[start:],
                                     right=right[start:],
                                     parent=parent[start:],
                                     child=child[start:])
            before = time.process_time()
            msprime.simplify_tables(samples=samples,
                                    nodes=subset_nodes,
                                    edges=subset_edges)
            duration = time.process_time() - before
            # np.float64 keeps the printed/CSV format of the original float
            # arrays and yields inf (not ZeroDivisionError) if duration == 0.
            edge_count = np.float64(ts.num_edges - start)
            rows_num_edges.append(edge_count)
            rows_time.append(duration)
            rows_sample_size.append(N)
            print(N, edge_count, duration, edge_count / duration,
                  "per second")

    num_edges = np.array(rows_num_edges, dtype=np.float64)
    simplify_time = np.array(rows_time, dtype=np.float64)
    sample_size = np.array(rows_sample_size, dtype=np.float64)

    df = pd.DataFrame({
        "sample_size": sample_size,
        "num_edges": num_edges,
        "time": simplify_time
    })
    df.to_csv("data/simplify_num_edges.dat")

    for N in sample_sizes:
        index = sample_size == N
        plt.plot(num_edges[index], simplify_time[index], marker="o",
                 label="n = {}".format(N))
    # Labels and the saved figure only need to be produced once, after all
    # sample-size lines have been drawn.
    plt.xlabel("num edges")
    plt.ylabel("Time to simplify (s)")
    plt.legend()
    plt.savefig("simplify_num_edges.png")