def generate_samples_empirical(ts):
    """
    Build a tsinfer samples object from a simulated tree sequence, passing
    the genotypes of every site through the empirical error model.

    For each variant the allele frequency is computed, the row of the
    module-level ``error_matrix`` whose ``freq`` value is nearest to it is
    selected, and both are handed to ``make_errors_genotype_model``. The
    genotypes it returns are stored at the original position with the
    original alleles.

    NOTE(review): any rejection of sites that become fixed is presumably
    done inside ``make_errors_genotype_model`` -- nothing in this function
    filters sites.

    Returns the finalised ``tsinfer.SampleData``.
    """
    assert ts.num_sites != 0
    samples = tsinfer.SampleData(sequence_length=ts.sequence_length)

    for variant in ts.variants():
        true_genotypes = variant.genotypes

        # Allele frequency: summed genotype values over the number of samples
        num_samples = true_genotypes.shape[0]
        allele_freq = np.sum(true_genotypes) / num_samples

        # Single row of the error matrix with the nearest 'freq' value
        distance = (error_matrix['freq'] - allele_freq).abs()
        nearest_row = error_matrix.iloc[distance.argsort()[:1]]

        noisy_genotypes = make_errors_genotype_model(true_genotypes, nearest_row)

        samples.add_site(
            position=variant.site.position,
            alleles=variant.alleles,
            genotypes=noisy_genotypes,
        )

    samples.finalise()
    return samples
Esempio n. 2
0
def read_samples(filename):
    """Convert a .csv genotype matrix into a tsinfer SampleData object.

    Every CSV row is one site. The fields of the row are concatenated and
    each character of the result is one sample's allele. The first character
    is taken as allele 0; every differing character is coded as genotype 1
    and the last differing character seen becomes allele 1. The row index is
    used as the site position and the number of rows as the sequence length.

    Rows in which all samples carry the same character are added with a
    single allele. Previously such a row reused the second allele of the
    preceding row, or raised UnboundLocalError on the first row. Blank rows
    are skipped instead of raising IndexError.

    Args:
        filename: base name (without ``.csv``) of a file in
            ``../data/samples/``.

    Returns:
        The ``tsinfer.SampleData`` holding the sites. It is not finalised
        here.
    """
    with open('../data/samples/' + filename + '.csv',
              encoding='utf-8',
              newline='') as file:
        rows = list(csv.reader(file))

    length = len(rows)
    sample_data = tsinfer.SampleData(sequence_length=length,
                                     num_flush_threads=2)
    for position, row in enumerate(rows):
        line = ''.join(row)
        if not line:
            # Blank CSV line: there is no reference allele to compare against.
            continue
        ref_allele = line[0]
        alt_allele = None  # reset per row so nothing leaks from the previous site
        genotypes = []
        for c in line:
            if c != ref_allele:
                genotypes.append(1)
                alt_allele = c
            else:
                genotypes.append(0)
        alleles = [ref_allele]
        if alt_allele is not None:
            alleles.append(alt_allele)
        sample_data.add_site(position, genotypes, alleles)
    return sample_data
Esempio n. 3
0
 def get_random_data_example(self, position, num_samples, seed=100):
     """
     Return a SampleData instance filled with seeded random 0/1 genotypes,
     one site per entry of ``position`` and ``num_samples`` samples per site.
     """
     np.random.seed(seed)
     matrix_shape = (position.shape[0], num_samples)
     genotype_matrix = np.random.randint(2, size=matrix_shape).astype(np.int8)
     with tsinfer.SampleData() as sample_data:
         for site_position, site_genotypes in zip(position, genotype_matrix):
             sample_data.add_site(site_position, site_genotypes)
     return sample_data
Esempio n. 4
0
 def test_sample_data(self, small_ts_fixture):
     """
     Copy every variant of the fixture tree sequence into a SampleData
     object, record a provenance entry and pass it to ``validate_file``.
     """
     with tsinfer.SampleData() as sample_data:
         for variant in small_ts_fixture.variants():
             site_position = variant.site.position
             sample_data.add_site(site_position, genotypes=variant.genotypes)
         sample_data.record_provenance("test", arg1=1, arg2=2)
     self.validate_file(sample_data)
Esempio n. 5
0
def tutorial_samples():
    """
    Simulate the tutorial data set and store it as a tsinfer samples file.

    Runs an msprime simulation with fixed parameters, dumps the result to
    ``tmp__NOBACKUP__/simulation-source.trees`` and then writes every variant
    to ``tmp__NOBACKUP__/simulation.samples``, advancing a tqdm progress bar
    once per site.
    """
    import tqdm
    import msprime
    import tsinfer

    simulation_options = dict(
        sample_size=10000,
        Ne=10**4,
        recombination_rate=1e-8,
        mutation_rate=1e-8,
        length=10 * 10**6,
        random_seed=42,
    )
    ts = msprime.simulate(**simulation_options)
    ts.dump("tmp__NOBACKUP__/simulation-source.trees")
    print("simulation done:", ts.num_trees, "trees and", ts.num_sites, "sites")

    site_progress = tqdm.tqdm(total=ts.num_sites)
    samples_file = tsinfer.SampleData(
        path="tmp__NOBACKUP__/simulation.samples",
        sequence_length=ts.sequence_length,
        num_flush_threads=2,
    )
    with samples_file as sample_data:
        for variant in ts.variants():
            sample_data.add_site(
                variant.site.position, variant.genotypes, variant.alleles)
            site_progress.update()
    site_progress.close()
Esempio n. 6
0
def main():
    """
    Command-line entry point: build a tsinfer samples file from a VCF.

    Reads the parsed options (``vcf``, ``outfile``, ``threads``,
    ``pops_header``, ``meta``), loads the tab-separated metadata table
    indexed by ``sampleID``, and writes ``<outfile>.samples`` via
    ``add_metadata`` and ``add_diploid_sites``. A one-line summary is printed
    at the end.
    """
    options = parse_args(sys.argv[1:])

    # --- inputs taken from the command line ---------------------------------
    vcf_filename = options.vcf
    outfile = options.outfile
    flush_threads = options.threads
    population_column = options.pops_header
    sample_metadata = pd.read_csv(
        options.meta, sep="\t", index_col="sampleID", dtype=object)

    # --- conversion ---------------------------------------------------------
    vcf_reader = cyvcf2.VCF(vcf_filename)
    sample_data_options = dict(
        path=f"{outfile}.samples",
        sequence_length=chrom_len(vcf_reader),
        num_flush_threads=flush_threads,
        max_file_size=2**37,
    )
    with tsinfer.SampleData(**sample_data_options) as samples:
        add_metadata(vcf_reader, samples, sample_metadata, population_column)
        add_diploid_sites(vcf_reader, samples)

    print(
        f"Sample file created for {samples.num_samples} samples ({samples.num_individuals}) with {samples.num_sites} variable sites.",
        flush=True)
Esempio n. 7
0
 def setUp(self):
     """
     Create a temporary directory holding a complete set of inference files.

     A small msprime simulation is written to a samples file, ancestors are
     generated and matched, samples are matched, and the ancestors tree
     sequence and the final tree sequence are dumped next to the samples
     file. The paths are stored on ``self`` for the tests to use.
     """
     self.tempdir = tempfile.TemporaryDirectory(prefix="tsinfer_cli_test")
     workdir = pathlib.Path(self.tempdir.name)
     self.sample_file = str(workdir / "input-data.samples")
     self.ancestor_file = str(workdir / "input-data.ancestors")
     self.ancestor_trees = str(workdir / "input-data.ancestors.trees")
     self.output_trees = str(workdir / "input-data.trees")

     self.input_ts = msprime.simulate(
         10, mutation_rate=10, recombination_rate=10, random_seed=10)

     samples = tsinfer.SampleData(
         sequence_length=self.input_ts.sequence_length,
         path=self.sample_file)
     for variant in self.input_ts.variants():
         samples.add_site(
             variant.site.position, variant.genotypes, variant.alleles)
     samples.finalise()

     # The ancestors are written to disk and then reloaded from that file.
     tsinfer.generate_ancestors(
         samples, path=self.ancestor_file, chunk_size=10)
     ancestors = tsinfer.load(self.ancestor_file)

     matched_ancestors = tsinfer.match_ancestors(samples, ancestors)
     matched_ancestors.dump(self.ancestor_trees)

     inferred = tsinfer.match_samples(samples, matched_ancestors)
     inferred.dump(self.output_trees)
     samples.close()
Esempio n. 8
0
 def get_random_data_example(self, num_sites, num_samples, seed=100):
     """
     Return a SampleData instance holding a seeded random 0/1 genotype
     matrix with ``num_sites`` sites at positions 0..num_sites-1 and
     ``num_samples`` samples.
     """
     np.random.seed(seed)
     genotype_matrix = np.random.randint(2, size=(num_sites, num_samples))
     genotype_matrix = genotype_matrix.astype(np.int8)
     with tsinfer.SampleData() as sample_data:
         for site_index, site_genotypes in enumerate(genotype_matrix):
             sample_data.add_site(site_index, site_genotypes)
     return sample_data
Esempio n. 9
0
 def test_sample_data(self):
     """
     Simulate a small tree sequence, copy its variants into a SampleData
     object with a provenance record, and pass it to ``validate_file``.
     """
     simulated = msprime.simulate(10, mutation_rate=1, random_seed=1)
     self.assertGreater(simulated.num_sites, 1)
     with tsinfer.SampleData() as sample_data:
         for variant in simulated.variants():
             site_position = variant.site.position
             sample_data.add_site(site_position, genotypes=variant.genotypes)
         sample_data.record_provenance("test", arg1=1, arg2=2)
     self.validate_file(sample_data)
Esempio n. 10
0
def generate_samples(
    ts, fn, aa_error="0", seq_error="0", empirical_seq_err_name=""):
    """
    Generate a samples file from a simulated ts, optionally adding
    genotyping error and ancestral-allele polarity error.

    :param ts: The simulated tree sequence; must contain at least one site.
    :param fn: Output path stem; ".samples" is appended to it.
    :param aa_error: Proportion (0..1) of sites whose genotypes are inverted
        (``1 - genotypes``). Converted with ``float``; a falsy value means 0.
        Exactly ``round(aa_error * ts.num_sites)`` sites are chosen at
        random.
    :param seq_error: When ``empirical_seq_err_name`` is empty this is
        converted with ``float`` and passed to ``make_seq_errors_simple``
        (or ``make_no_errors`` when it is 0). Otherwise it is passed
        unchanged to ``make_seq_errors_genotype_model``.
    :param empirical_seq_err_name: Non-empty to select the empirical error
        model; the string itself is only used in the log message.
    :return: The finalised ``tsinfer.SampleData``.
    """
    record_rate = logging.getLogger().isEnabledFor(logging.INFO)
    n_variants = bits_flipped = bad_ancestors = 0
    assert ts.num_sites != 0
    fn += ".samples"
    sample_data = tsinfer.SampleData(path=fn, sequence_length=ts.sequence_length)

    # Setup the sequencing error used. Empirical error should be a matrix not a float
    if not empirical_seq_err_name:
        seq_error = float(seq_error) if seq_error else 0
        if seq_error == 0:
            record_rate = False  # no point recording the achieved error rate
            sequencing_error = make_no_errors
        else:
            logging.info("Adding genotyping error: {} used in file {}".format(
                seq_error, fn))
            sequencing_error = make_seq_errors_simple
    else:
        logging.info("Adding empirical genotyping error: {} used in file {}".format(
            empirical_seq_err_name, fn))
        sequencing_error = make_seq_errors_genotype_model

    # Setup the ancestral state error used.
    aa_error = float(aa_error) if aa_error else 0
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    aa_error_by_site = np.zeros(ts.num_sites, dtype=bool)
    if aa_error > 0:
        assert aa_error <= 1
        n_bad_sites = round(aa_error*ts.num_sites)
        logging.info("Adding ancestral allele polarity error: {}% ({}/{} sites) used in file {}"
            .format(aa_error * 100, n_bad_sites, ts.num_sites, fn))
        # This gives *exactly* a proportion aa_error of bad sites.
        # NB - to do this probabilistically, use np.binomial(1, e, ts.num_sites)
        aa_error_by_site[0:n_bad_sites] = True
        np.random.shuffle(aa_error_by_site)
        assert np.sum(aa_error_by_site) == n_bad_sites

    for ancestral_allele_error, v in zip(aa_error_by_site, ts.variants()):
        n_variants += 1
        genotypes = sequencing_error(v.genotypes, seq_error)
        if record_rate:
            # Count the flips before any polarity inversion is applied.
            bits_flipped += np.sum(np.logical_xor(genotypes, v.genotypes))
            bad_ancestors += ancestral_allele_error
        if ancestral_allele_error:
            # Swap 0 <-> 1 to mimic a wrongly polarised ancestral allele.
            genotypes = 1 - genotypes
        sample_data.add_site(
            position=v.site.position, alleles=v.alleles, genotypes=genotypes)
    if record_rate:
        logging.info(" actual error rate = {} over {} sites before {} ancestors flipped"
            .format(bits_flipped/(n_variants*ts.sample_size), n_variants, bad_ancestors))

    sample_data.finalise()
    return sample_data
Esempio n. 11
0
 def test_inferred_random_data(self):
     """
     Infer a tree sequence from a seeded random binary genotype matrix
     (40 sites, 8 samples) and hand the result to ``verify``.
     """
     np.random.seed(10)
     matrix_shape = (40, 8)  # (sites, samples)
     genotype_matrix = np.random.randint(2, size=matrix_shape).astype(np.int8)
     with tsinfer.SampleData() as sample_data:
         for site_index, site_genotypes in enumerate(genotype_matrix):
             sample_data.add_site(site_index, site_genotypes)
     inferred = tsinfer.infer(sample_data)
     self.verify(inferred)
def infer_from_msprime(simulation):
    ''' Run tsinfer on the variants of an msprime simulation and return the
        inferred tree sequence.
        Args: simulation - msprime output; its sequence_length and
              variants() are read
    '''
    with tsinfer.SampleData(
            sequence_length=simulation.sequence_length,
            num_flush_threads=2) as samples:
        for variant in simulation.variants():
            samples.add_site(
                variant.site.position, variant.genotypes, variant.alleles)
    return tsinfer.infer(samples)
def generate_samples(ts):
    """
    Build a samples object from a simulated ts.

    Every variant is copied unchanged (position, alleles and genotypes) into
    a new ``tsinfer.SampleData`` of the same sequence length, which is
    finalised and returned. The tree sequence must contain at least one site.
    """
    assert ts.num_sites != 0

    samples = tsinfer.SampleData(sequence_length=ts.sequence_length)
    for variant in ts.variants():
        site_fields = dict(
            position=variant.site.position,
            alleles=variant.alleles,
            genotypes=variant.genotypes,
        )
        samples.add_site(**site_fields)

    samples.finalise()
    return samples
Esempio n. 14
0
def build_profile_inputs(n, num_megabases):
    """
    Prepare the input files used for profiling.

    Loads the cached simulation for ``n`` samples and ``num_megabases``
    megabases from ``tmp__NOBACKUP__`` if it exists, otherwise runs the
    simulation and caches it. The matching ``.samples`` file is then rebuilt
    from scratch (an existing one is deleted first) with a progress bar over
    the sites, and the resulting sample data is printed.
    """
    trees_path = "tmp__NOBACKUP__/profile-n={}-m={}.input.trees".format(
        n, num_megabases)
    if not os.path.exists(trees_path):
        simulated = msprime.simulate(
            n,
            length=num_megabases * 10**6,
            Ne=10**4,
            recombination_rate=1e-8,
            mutation_rate=1e-8,
            random_seed=10,
        )
        print(
            "Ran simulation: n = ",
            n,
            " num_sites = ",
            simulated.num_sites,
            "num_trees =",
            simulated.num_trees,
        )
        simulated.dump(trees_path)
    else:
        simulated = msprime.load(trees_path)

    samples_path = "tmp__NOBACKUP__/profile-n={}-m={}.samples".format(
        n, num_megabases)
    # Always start from a fresh samples file.
    if os.path.exists(samples_path):
        os.unlink(samples_path)

    with tsinfer.SampleData(sequence_length=simulated.sequence_length,
                            path=samples_path,
                            num_flush_threads=4) as sample_data:
        site_progress = tqdm.tqdm(total=simulated.num_sites)
        for variant in simulated.variants():
            sample_data.add_site(variant.site.position, variant.genotypes)
            site_progress.update()
        site_progress.close()

    print(sample_data)
Esempio n. 15
0
def convert(
        vcf_file, pedigree_file, output_file, max_variants=None, show_progress=False):
    """
    Convert a VCF file plus a pedigree file into a tsinfer samples file.

    Populations are added first, then the samples described by the pedigree
    file for the individuals named in the VCF, then the sites produced by
    ``variants(vcf_file, show_progress)``. A provenance record with the
    command line is stored at the end.

    :param vcf_file: Path of the VCF; opened once to read the sample names
        and passed again to ``variants``.
    :param pedigree_file: Path of the pedigree file handed to ``add_samples``.
    :param output_file: Path of the samples file to write.
    :param max_variants: Maximum number of sites to store; ``None`` means no
        practical limit. Exactly this many sites are written at most; the
        previous code wrote one extra because it tested the index only after
        adding the site.
    :param show_progress: Forwarded to ``variants``.
    """
    if max_variants is None:
        max_variants = 2**32  # Arbitrary, but > defined max for VCF

    with tsinfer.SampleData(path=output_file, num_flush_threads=2) as sample_data:
        pop_id_map = add_populations(sample_data)

        vcf = cyvcf2.VCF(vcf_file)
        individual_names = list(vcf.samples)
        vcf.close()

        with open(pedigree_file, "r") as ped_file:
            add_samples(ped_file, pop_id_map, individual_names, sample_data)

        for index, site in enumerate(variants(vcf_file, show_progress)):
            # Check the limit before storing so that at most max_variants
            # sites end up in the file.
            if index >= max_variants:
                break
            sample_data.add_site(
                position=site.position, genotypes=site.genotypes,
                alleles=site.alleles, metadata=site.metadata)
        sample_data.record_provenance(command=sys.argv[0], args=sys.argv[1:])
Esempio n. 16
0
import os
import sys

import msprime

sys.path.insert(0, os.path.abspath(".."))
import tsinfer  # noqa


# Simulate a tiny tree sequence and show its site count and first tree.
ts = msprime.simulate(5, mutation_rate=0.7, random_seed=10)
tree = ts.first()
print(ts.num_sites)
print(tree.draw(format="unicode"))

# Hand-written toy data set: (position, genotypes, alleles) for each site.
toy_sites = (
    (10, [0, 1, 0, 0, 0], ["A", "T"]),
    (12, [0, 0, 0, 1, 1], ["G", "C"]),
    (23, [0, 1, 1, 0, 0], ["C", "A"]),
    (37, [0, 1, 1, 0, 0], ["G", "C"]),
    (40, [0, 0, 0, 1, 1], ["A", "C"]),
    (50, [0, 1, 0, 0, 0], ["T", "G"]),
)
with tsinfer.SampleData(path="toy.samples") as sample_data:
    for toy_position, toy_genotypes, toy_alleles in toy_sites:
        sample_data.add_site(toy_position, toy_genotypes, toy_alleles)

print(sample_data)

# Infer a tree sequence from the toy data, then print its trees and haplotypes.
inferred_ts = tsinfer.infer(sample_data)
for tree in inferred_ts.trees():
    print(tree.draw(format="unicode"))

for sample_id, h in enumerate(inferred_ts.haplotypes()):
    print(sample_id, h, sep="\t")
Esempio n. 17
0
# Manual toggle: the first branch (re)creates the simulation and its samples
# file; the else branch is never executed while the condition is the literal
# True and has to be enabled by editing this line.
if True:
    # Simulate 10000 samples over a 10 Mb sequence with a fixed seed.
    ts = msprime.simulate(
        sample_size=10000,
        Ne=10**4,
        recombination_rate=1e-8,
        mutation_rate=1e-8,
        length=10 * 10**6,
        random_seed=42,
    )
    ts.dump("simulation-source.trees")
    print("Simulation done:", ts.num_trees, "trees and", ts.num_sites)

    # Copy every variant into simulation.samples, with a progress bar whose
    # total is the number of sites.
    with tsinfer.SampleData(
            sequence_length=ts.sequence_length,
            path="simulation.samples",
            num_flush_threads=2,
    ) as samples:
        for var in tqdm.tqdm(ts.variants(), total=ts.num_sites):
            samples.add_site(var.site.position, var.genotypes, var.alleles)

else:
    # Load the source simulation and an inferred tree sequence from disk.
    # NOTE(review): "simulation.trees" is not written anywhere in this
    # fragment; presumably it is produced by a separate inference run.
    source = msprime.load("simulation-source.trees")
    inferred = msprime.load("simulation.trees")

    # Restrict both tree sequences to the first six samples.
    subset = range(0, 6)
    source_subset = source.simplify(subset)
    inferred_subset = inferred.simplify(subset)

    # Show the genomic interval covered by the first tree of the true subset.
    tree = source_subset.first()
    print("True tree: interval=", tree.interval)
Esempio n. 18
0
def main():
    """
    Command-line entry point converting VCF data into a tsinfer samples file.

    Parses the arguments, collects git and data provenance, reads the
    ancestral states from the first reference of the given file, and runs
    the converter selected by ``source`` over the input. If anything fails
    while the samples file is being written, the output file is removed and
    the exception is re-raised. On success the sample data is printed.
    """
    parser = argparse.ArgumentParser(
        description="Script to convert VCF files into tsinfer input.")
    parser.add_argument(
        "source", choices=["1kg", "sgdp", "ukbb"],
        help="The source of the input data.")
    parser.add_argument("data_file", help="The input data file pattern.")
    parser.add_argument(
        "ancestral_states_file",
        help="A vcf file containing ancestral allele states. ")
    parser.add_argument("output_file", help="The tsinfer output file")
    parser.add_argument(
        "-m", "--metadata_file", default=None,
        help="The metadata file containing population and sample data")
    parser.add_argument(
        "-n", "--max-variants", default=None, type=int,
        help="Keep only the first n variants")
    parser.add_argument(
        "-p", "--progress", action="store_true",
        help="Show progress bars and output extra information when done")
    parser.add_argument(
        "--ancestral-states-url", default=None,
        help="The source of ancestral state information for provenance.")
    parser.add_argument(
        "--reference-name", default=None,
        help="The name of the reference for provenance.")

    args = parser.parse_args()

    # Provenance: current git commit of the working directory plus the data
    # sources named on the command line.
    head_commit = subprocess.check_output(["git", "rev-parse", "HEAD"])
    git_provenance = {
        "repo": "[email protected]:mcveanlab/treeseq-inference.git",
        "hash": head_commit.decode().strip(),
        "dir": "human-data",
        "notes:": (
            "Use the Makefile to download and process the upstream data files"),
    }
    data_provenance = {
        "ancestral_states_url": args.ancestral_states_url,
        "reference_name": args.reference_name,
    }

    # Ancestral states, padded with one leading character so that string
    # indexes line up with 1-based coordinates.
    fasta = pysam.FastaFile(args.ancestral_states_file)
    first_reference = fasta.references[0]
    ancestral_states = "X" + fasta.fetch(reference=first_reference)
    # The largest possible site position is len(ancestral_states), and
    # positions must be strictly below sequence_length, hence the + 1.
    sequence_length = len(ancestral_states) + 1

    converter_class = {
        "1kg": ThousandGenomesConverter,
        "sgdp": SgdpConverter,
        "ukbb": UkbbConverter,
    }

    try:
        with tsinfer.SampleData(
                path=args.output_file,
                num_flush_threads=2,
                sequence_length=sequence_length) as samples:
            make_converter = converter_class[args.source]
            converter = make_converter(
                args.data_file, ancestral_states, samples)
            converter.process_metadata(args.metadata_file, args.progress)
            converter.process_sites(args.progress, args.max_variants)
            samples.record_provenance(
                command=sys.argv[0],
                args=sys.argv[1:],
                git=git_provenance,
                data=data_provenance)
    except Exception as e:
        # Do not leave a partially written samples file behind.
        os.unlink(args.output_file)
        raise e
    print(samples)
Esempio n. 19
0
#ancient_sample_indices = [i for i, e in enumerate(sample_list) if e.time != 0]

#def simulate_ts():
#    modern_samples = [msprime.Sample(population=0, time=0) for x in range(90)]
#    ancient_samples = [msprime.Sample(population=0, time=ANCIENT_TIME) for x in range(10)]
#    samples = modern_samples + ancient_samples
#    return(msprime.simulate(samples=samples, mutation_rate=1e-8,
#                            recombination_rate=1e-8, length=1e4,
#                            Ne=10000))
#
#ts = simulate_ts()

# Convert the tree sequence `ts` into sample data. `ts`,
# `modern_sample_indices` and `ancient_sample_indices` are not defined in the
# visible part of this script.
# NOTE(review): `sample_data` is not referenced again in the visible lines.
sample_data = tsinfer.formats.SampleData.from_tree_sequence(ts)

# Remove non modern samples from tree sequence
modern_samples = tsinfer.SampleData(path="modern_only_hiv.samples",
                                    sequence_length=ts.sequence_length)

# One haploid individual with empty metadata per modern sample index.
for individual in range(len(modern_sample_indices)):
    modern_samples.add_individual(ploidy=1, metadata={})

# Copy every site, keeping only the genotype entries selected by
# modern_sample_indices.
for v in ts.variants():
    modern_samples.add_site(position=v.site.position,
                            alleles=v.alleles,
                            genotypes=v.genotypes[modern_sample_indices])
modern_samples.finalise()

# make ancient sample data files
ancient_samples = tsinfer.SampleData(path="ancient_only_hiv.samples",
                                     sequence_length=ts.sequence_length)
# One haploid individual with empty metadata per ancient sample index.
# NOTE(review): no sites are added to ancient_samples and it is not finalised
# in the visible lines -- the script appears to be truncated here.
for individual in range(len(ancient_sample_indices)):
    ancient_samples.add_individual(ploidy=1, metadata={})
def make_sampledata(args):
    """
    Build a tsinfer samples file from the data source described by ``args``.

    :param args: Either the parsed argument namespace, or a tuple
        ``(namespace, output_file, vcf_subset)``. In the tuple form the
        namespace's ``output_file`` attribute is overwritten with
        ``str(output_file)`` and ``vcf_subset`` is forwarded to
        ``process_sites``.
    :return: The report returned by the converter's ``process_sites``. When
        it says that no sites were written, the output file is deleted.
    :raises Exception: Any error raised during conversion is re-raised after
        the partially written output file has been removed, unless a report
        with zero sites had already been produced, in which case that report
        is returned instead.
    """
    if isinstance(args, tuple):
        vcf_subset = args[2]
        args[0].output_file = str(args[1])
        args = args[0]
    else:
        vcf_subset = None

    try:
        git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])
        git_provenance = {
            "repo":
            "[email protected]:mcveanlab/treeseq-inference.git",
            "hash":
            git_hash.decode().strip(),
            "dir":
            "human-data",
            "notes:":
            ("Use the Makefile to download and process the upstream data files"
             ),
        }
    except FileNotFoundError:
        # The git executable is not installed.
        git_provenance = "Git unavailable"
    data_provenance = {
        "ancestral_states_url": args.ancestral_states_url,
        "reference_name": args.reference_name,
    }

    # Get the ancestral states.
    fasta = pysam.FastaFile(args.ancestral_states_file)
    # NB! We put in an extra character at the start to convert to 1 based coords.
    ancestral_states = "X" + fasta.fetch(reference=fasta.references[0])
    # The largest possible site position is len(ancestral_states). Positions must
    # be strictly less than sequence_length, so we add 1.
    sequence_length = len(ancestral_states) + 1

    converter_class = {
        "1kg": ThousandGenomesConverter,
        "sgdp": SgdpConverter,
        "hgdp": HgdpConverter,
        "max-planck": MaxPlanckConverter,
        "afanasievo": AfanasievoConverter,
        "1240k": ReichConverter,
    }

    # Bound before the try block so the error handler can tell whether
    # process_sites got far enough to produce a report.
    report = None
    try:
        with tsinfer.SampleData(path=args.output_file,
                                num_flush_threads=1,
                                sequence_length=sequence_length) as samples:
            converter = converter_class[args.source](args.data_file,
                                                     ancestral_states, samples,
                                                     args.target_samples)
            if args.metadata_file:
                converter.process_metadata(args.metadata_file, args.progress)
            else:
                # NOTE(review): here args.progress is passed as the first
                # positional argument, where the other branch passes the
                # metadata file -- confirm against process_metadata.
                converter.process_metadata(args.progress)
            if vcf_subset is not None:
                report = converter.process_sites(
                    vcf_subset=vcf_subset,
                    show_progress=args.progress,
                    max_sites=args.max_variants,
                )
            else:
                report = converter.process_sites(show_progress=args.progress,
                                                 max_sites=args.max_variants)
            samples.record_provenance(
                command=sys.argv[0],
                args=sys.argv[1:],
                git=git_provenance,
                data=data_provenance,
            )
            assert np.all(np.diff(samples.sites_position[:]) > 0)
    except Exception:
        # Remove the partial output, but only if it was actually created, so
        # that a missing file does not mask the original error.
        if os.path.exists(args.output_file):
            os.unlink(args.output_file)
        if report is not None and report["num_sites"] == 0:
            return report
        raise
    if report["num_sites"] == 0:
        os.unlink(args.output_file)
    return report
Esempio n. 21
0
def run_combine_ukbb_1kg(args):
    """
    Combine a 1kg tree sequence with UKBB sample data for one chromosome.

    Reads ``ukbb_<chromosome>.samples`` and ``1kg_<chromosome>.trees`` and
    writes two files:

    * ``1kg_ukbb_<chromosome>.ancestors.trees`` -- the 1kg tree sequence
      restricted to the site positions shared with the UKBB data, recoded to
      0/1 states, with a new node 0 added above every root.
    * ``1kg_ukbb_<chromosome>.samples`` -- the UKBB samples restricted to the
      same positions and to the first ``args.num_individuals`` individuals
      (all of them when that is None).

    ``args`` must provide ``chromosome`` and ``num_individuals``.
    """
    ukbb_samples_file = "ukbb_{}.samples".format(args.chromosome)
    tg_ancestors_ts_file = "1kg_{}.trees".format(args.chromosome)
    ancestors_ts_file = "1kg_ukbb_{}.ancestors.trees".format(args.chromosome)
    samples_file = "1kg_ukbb_{}.samples".format(args.chromosome)

    ukbb_samples = tsinfer.load(ukbb_samples_file)
    tg_ancestors_ts = tskit.load(tg_ancestors_ts_file)
    print("Loaded ts:", tg_ancestors_ts.num_nodes, tg_ancestors_ts.num_edges)

    # Subset the sites down to the UKBB sites.
    tables = tg_ancestors_ts.dump_tables()
    ukbb_sites = set(ukbb_samples.sites_position[:])
    ancestors_sites = set(tables.sites.position[:])
    intersecting_sites = ancestors_sites & ukbb_sites

    print("Intersecting sites = ", len(intersecting_sites))
    # Rebuild the site and mutation tables from scratch with only the shared
    # positions.
    tables.sites.clear()
    tables.mutations.clear()
    for site in tg_ancestors_ts.sites():
        if site.position in intersecting_sites:
            # Sites must be 0/1 for the ancestors ts.
            site_id = tables.sites.add_row(position=site.position,
                                           ancestral_state="0")
            # Every retained site is required to carry exactly one mutation.
            assert len(site.mutations) == 1
            mutation = site.mutations[0]
            tables.mutations.add_row(site=site_id,
                                     node=mutation.node,
                                     derived_state="1")

    # Reduce this to the site topology now to make things as quick as possible.
    tables.simplify(reduce_to_site_topology=True, filter_sites=False)
    # Snapshot taken before the node rewrite below; its trees are used later
    # to find the root over each interval.
    reduced_ts = tables.tree_sequence()
    # Rewrite the nodes so that 0 is one older than all the other nodes.
    nodes = tables.nodes.copy()
    tables.nodes.clear()
    tables.nodes.add_row(flags=1, time=np.max(nodes.time) + 2)
    tables.nodes.append_columns(
        flags=np.bitwise_or(nodes.flags, 1),  # Everything is a sample.
        time=nodes.time + 1,  # Make sure that all times are > 0
        population=nodes.population,
        individual=nodes.individual,
        metadata=nodes.metadata,
        metadata_offset=nodes.metadata_offset)
    # Add one to all node references to account for this.
    tables.edges.set_columns(left=tables.edges.left,
                             right=tables.edges.right,
                             parent=tables.edges.parent + 1,
                             child=tables.edges.child + 1)
    tables.mutations.set_columns(
        node=tables.mutations.node + 1,
        site=tables.mutations.site,
        parent=tables.mutations.parent,
        derived_state=tables.mutations.derived_state,
        derived_state_offset=tables.mutations.derived_state_offset,
        metadata=tables.mutations.metadata,
        metadata_offset=tables.mutations.metadata_offset)

    # Walk the trees of the pre-rewrite snapshot and, for every maximal run of
    # trees sharing a root, add an edge from the new node 0 to that root
    # (shifted by one to match the rewritten node ids).
    # NOTE(review): the positional arguments to add_row are assumed to be
    # (left, right, parent, child) -- confirm against the tskit EdgeTable API.
    trees = reduced_ts.trees()
    tree = next(trees)
    left = 0
    root = tree.root
    for tree in trees:
        if tree.root != root:
            tables.edges.add_row(left, tree.interval[0], 0, root + 1)
            root = tree.root
            left = tree.interval[0]
    # Close the final run up to the end of the sequence.
    tables.edges.add_row(left, reduced_ts.sequence_length, 0, root + 1)
    tables.sort()
    ancestors_ts = tables.tree_sequence()
    print("Writing ancestors_ts")
    ancestors_ts.dump(ancestors_ts_file)

    # Now create a new samples file to get rid of the missing sites.
    git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])
    git_provenance = {
        "repo":
        "[email protected]:mcveanlab/treeseq-inference.git",
        "hash":
        git_hash.decode().strip(),
        "dir":
        "human-data",
        "notes:":
        ("Use the Makefile to download and process the upstream data files")
    }

    # Number of individuals to keep; defaults to all of them.
    n = args.num_individuals
    if n is None:
        n = ukbb_samples.num_individuals
    with tsinfer.SampleData(
            path=samples_file,
            num_flush_threads=4,
            sequence_length=ukbb_samples.sequence_length) as samples:

        # Copy the first n individuals.
        # NOTE(review): the iterator is wrapped in tqdm twice (inside and
        # outside islice), which presumably shows two progress bars.
        iterator = tqdm.tqdm(itertools.islice(
            tqdm.tqdm(ukbb_samples.individuals()), n),
                             total=n)
        for ind in iterator:
            samples.add_individual(ploidy=2,
                                   location=ind.location,
                                   metadata=ind.metadata)

        # Copy the shared sites, keeping the first 2 * n genotype entries.
        # NOTE(review): assumes the first n diploid individuals own the first
        # 2 * n samples -- verify against the UKBB samples file layout.
        for variant in tqdm.tqdm(ukbb_samples.variants(),
                                 total=ukbb_samples.num_sites):
            if variant.site.position in intersecting_sites:
                samples.add_site(position=variant.site.position,
                                 alleles=variant.alleles,
                                 genotypes=variant.genotypes[:2 * n],
                                 metadata=variant.site.metadata)

        # Carry over the existing provenance, then add a record for this run.
        for timestamp, record in ukbb_samples.provenances():
            samples.add_provenance(timestamp, record)
        samples.record_provenance(command=sys.argv[0],
                                  args=sys.argv[1:],
                                  git=git_provenance)

    print(samples)
Esempio n. 22
0
def tsinfer_dev(
    n,
    L,
    seed,
    num_threads=1,
    recombination_rate=1e-8,
    error_rate=0,
    engine="C",
    log_level="WARNING",
    precision=None,
    debug=True,
    progress=False,
    path_compression=True,
):
    """
    Development driver: simulate data, run the three inference steps one
    after another and verify the result against the input samples.

    ``n`` samples are simulated over ``L`` megabases with the given ``seed``
    and ``recombination_rate``. Ancestors are generated, matched, and the
    samples are matched against them using ``engine`` and ``precision``; the
    recombination rate and a fixed mismatch ratio of 100 are passed to both
    matching steps. Intermediate objects and the final edge count are
    printed.

    NOTE(review): ``error_rate``, ``log_level``, ``progress`` and
    ``path_compression`` are accepted but not read; path compression is
    hard-coded to True for ancestor matching and False for sample matching.
    """
    np.random.seed(seed)
    random.seed(seed)
    sequence_length = int(L * 10**6)

    source_ts = msprime.simulate(
        n,
        Ne=10**4,
        length=sequence_length,
        recombination_rate=recombination_rate,
        mutation_rate=1e-8,
        random_seed=seed,
    )
    if debug:
        print("num_sites = ", source_ts.num_sites)
    assert source_ts.num_sites > 0

    with tsinfer.SampleData(
            sequence_length=source_ts.sequence_length) as samples:
        for variant in source_ts.variants():
            samples.add_site(
                variant.site.position, variant.genotypes, variant.alleles)
    print(samples)

    # Options shared by the two matching steps.
    matching_options = dict(
        engine=engine,
        precision=precision,
        recombination_rate=recombination_rate,
        mismatch_ratio=100,
    )

    ancestor_data = tsinfer.generate_ancestors(
        samples, engine=engine, num_threads=num_threads)
    print(ancestor_data)

    ancestors_ts = tsinfer.match_ancestors(
        samples,
        ancestor_data,
        path_compression=True,
        extended_checks=False,
        **matching_options,
    )

    ts = tsinfer.match_samples(
        samples,
        ancestors_ts,
        path_compression=False,
        simplify=False,
        **matching_options,
    )

    print("num_edges = ", ts.num_edges)

    tsinfer.verify(samples, ts)