Exemple #1
0
    def _link_populations(self, popdict=None, popmins=None):
        """
        Creates self.populations dictionary to save mappings of individuals to
        populations/sites. The self.populations dict keys are pop names and
        the values are 2-tuples. The first element is the min number of
        samples per pop for final filtering of loci, and the second element
        is the list of samples per pop.

        Population assignments are used for hierarchical clustering, for
        generating summary stats, and for outputting some file types
        (.treemix for example). Internally stored as a dictionary.

        Note
        ----
        By default a file is read in from `pop_assign_file` with one
        individual per line and whitespace separated pairs of ind pop:

            ind1 pop1
            ind2 pop2
            ind3 pop3
            etc...

        Lines starting with '#' hold the min samples per population as
        `pop:min` pairs (commas are ignored), and at least one such line is
        required when reading from a file:

            # pop1:3, pop2:2

        Parameters
        ----------
        popdict : dict
            When using the API it may be easier to simply create a dictionary
            to pass in as an argument instead of reading from an input file.
            This can be done with the `popdict` argument like below:

            pops = {'pop1': ['ind1', 'ind2', 'ind3'], 'pop2': ['ind4', 'ind5']}
            [Assembly]._link_populations(popdict=pops).

        popmins : dict
            If you want to apply a minsamples filter based on populations
            you can add a popmins dictionary. This indicates the number of
            samples in each population that must be present in a locus for
            the locus to be retained. When `popdict` is given without
            `popmins` every population gets a minimum of 1. Example:

            popmins = {'pop1': 3, 'pop2': 2}

        Raises
        ------
        IPyradError
            If the populations file cannot be found or parsed, if it has no
            '#' min-samples line, or if a population minimum is larger than
            the number of samples assigned to that population.
        """
        if not popdict:
            ## glob it in case of fuzzy matching. An empty match list must be
            ## caught here, indexing [0] on it would raise a bare IndexError.
            matches = glob.glob(self.params.pop_assign_file)
            if not matches or not os.path.exists(matches[0]):
                raise IPyradError(
                    "Population assignment file not found: {}".format(
                        self.params.pop_assign_file))
            popfile = matches[0]

            try:
                ## parse populations file ('#' lines are skipped as comments)
                popdat = pd.read_csv(popfile,
                                     header=None,
                                     sep=r"\s+",
                                     names=["inds", "pops"],
                                     comment="#")

                popdict = {
                    key: group.inds.values.tolist()
                    for key, group in popdat.groupby("pops")
                }

                ## parse minsamples per population if present (line with #)
                with open(popfile, 'r') as infile:
                    mindat = [
                        line.lstrip("#").strip()
                        for line in infile
                        if line.startswith("#")
                    ]

                if not mindat:
                    raise IPyradError(MIN_SAMPLES_PER_POP_MALFORMED)

                popmins = {}
                for line in mindat:
                    for item in line.replace(",", "").split():
                        fields = item.split(':')
                        popmins[fields[0]] = int(fields[1])

            ## IndexError: a min-samples entry without a ':' separator
            except (ValueError, IOError, IndexError):
                raise IPyradError(
                    "  Populations file malformed - {}".format(popfile))

        elif not popmins:
            ## pop dict is provided by user without mins: require 1 per pop
            popmins = {i: 1 for i in popdict}

        ## Sample names in popdict are deliberately not checked against
        ## self.samples here: the pops file may be set on a new assembly
        ## that has no linked samples yet.

        ## If popmins not set, just assume all mins are zero
        if not popmins:
            popmins = {i: 0 for i in popdict}

        ## check popmins
        ## cannot have higher min for a pop than there are samples in the pop
        if any(popmins[i] > len(popdict[i]) for i in popdict):
            raise IPyradError(
                " minsample per pop value cannot be greater than the " +
                " number of samples in the pop. Modify the populations file.")

        ## store the mapping
        self.populations = {i: (popmins[i], popdict[i]) for i in popdict}
Exemple #2
0
def check_name(name):
    """
    Validate an assembly name.

    A name may not contain a space or any ASCII punctuation character other
    than underscore and hyphen. Raises IPyradError (message built from
    BAD_ASSEMBLY_NAME) on the first offending name; returns None otherwise.
    """
    # every punctuation char except the two permitted ones, plus the space
    forbidden = "".join(
        char for char in string.punctuation if char not in "_-") + " "

    for char in name:
        if char in forbidden:
            raise IPyradError(BAD_ASSEMBLY_NAME.format(name))
Exemple #3
0
    def branch_assembly(self):
        """
        Load the passed in assembly and create a branch. Copy it
        to a new assembly, and also write out the appropriate params.txt

        The branch arguments are read from ``self.args.branch``: the first
        item is the new assembly name (a trailing ".txt" is trimmed), and any
        further items select a subset of samples. The subset is either a list
        of sample names or the path of an existing file whose first
        whitespace-separated field on each non-blank line is a sample name.
        A leading "-" item means the listed samples are dropped rather than
        kept.

        Raises
        ------
        IPyradError
            If "-" is given without any sample names, or if a listed name is
            not a key of ``self.data.samples``.
        """
        # get arguments to branch command
        bargs = self.args.branch

        # get new name, trimming .txt if it was accidentally added
        newname = bargs[0]
        if newname.endswith(".txt"):
            newname = newname[:-4]

        # None means keep all samples
        subsamples = None

        # look for subsample arguments
        if len(bargs) > 1:

            # parse str or file of names to include/drop
            subargs = bargs[1:]

            # is there a '-' indicating to drop
            remove = subargs[0] == "-"
            if remove:
                subargs = subargs[1:]

            # a lone '-' leaves nothing to select; fail with a clear message
            # instead of an IndexError on subargs[0] below
            if not subargs:
                raise IPyradError(
                    "\n  Failed: no sample names were given after '-'.")

            # is sample list a file?
            if os.path.exists(subargs[0]):
                with open(subargs[0], 'r') as infile:
                    subsamples = [
                        line.split()[0] for line in infile if line.strip()
                    ]
            else:
                subsamples = subargs

            # check subsample drop names
            known = set(self.data.samples.keys())
            fails = [name for name in subsamples if name not in known]
            if fails:
                raise IPyradError(
                    "\n  Failed: unrecognized names, check spelling:\n  {}"
                        .format("\n  ".join(fails)))

            # if drop then get subtract list
            if remove:
                print("  dropping {} samples".format(len(subsamples)))
                subsamples = list(known - set(subsamples))

        new_data = self.data.branch(newname, subsamples)

        print("  creating a new branch called '{}' with {} Samples"
              .format(new_data.name, len(new_data.samples)))

        print("  writing new params file to {}"
              .format("params-" + new_data.name + ".txt\n"))

        new_data.write_params(
            "params-" + new_data.name + ".txt",
            force=self.args.force)
Exemple #4
0
    def _link_barcodes(self):
        """
        Parses Sample barcodes to a dictionary from 'barcodes_path'. This
        function is called whenever a barcode-ish param is changed.

        The barcodes file is whitespace delimited with no header: sample name
        in the first column, barcode in the second and, for '3rad' datatypes,
        a second barcode in the third column. The result is stored on
        ``self.barcodes`` as {sample_name: barcode}, with 3rad barcodes joined
        as "bar1+bar2" and characters from BADCHARS in sample names replaced
        by "_".

        Raises
        ------
        IPyradError
            If no file matches ``self.params.barcodes_path``, if a barcode in
            the second column holds a character outside "RKSYWMCATG", or if
            a 3rad datatype file does not have exactly three columns.
        """
        # find barcodefile
        barcodefile = glob.glob(self.params.barcodes_path)
        if not barcodefile:
            raise IPyradError(
                "Barcodes file not found. You entered: {}".format(
                    self.params.barcodes_path))

        # read in the file. Everything is kept as text so that purely numeric
        # sample names stay strings for the name clean-up at the end.
        bdf = pd.read_csv(
            barcodefile[0], header=None, sep=r"\s+", dtype=str)
        bdf = bdf.dropna()

        # make sure bars are upper case
        bdf[1] = bdf[1].str.upper()

        # names that occur more than once are technical replicates. Select
        # on the counts themselves: the index of the boolean (counts > 1)
        # Series would list every name, replicated or not.
        counts = bdf[0].value_counts()
        repeated = counts[counts > 1].index

        # if replicates are present
        if len(repeated):

            # print a warning about dups if the data are not demultiplexed
            if not self.samples:
                self._print(
                    "Warning: technical replicates (same name) present.")

            # labels technical reps in barcode dict by appending
            # -technical-replicate-N to each replicated name
            for rep in repeated:
                farr = bdf[bdf[0] == rep]
                for idx, index in enumerate(farr.index):
                    bdf.loc[index,
                            0] = ("{}-technical-replicate-{}".format(rep, idx))

        # make sure chars are all proper
        if not all(bdf[1].apply(set("RKSYWMCATG").issuperset)):
            raise IPyradError(BAD_BARCODE)

        # store barcodes as a dict
        self.barcodes = dict(zip(bdf[0], bdf[1]))

        # 3rad/seqcap use multiplexed barcodes
        if "3rad" in self.params.datatype:
            if not bdf.shape[1] == 3:
                raise IPyradError(
                    "pair3rad datatype should have two barcodes per sample.")

            # We'll concatenate them with a plus and split them later
            bdf[2] = bdf[2].str.upper()
            self.barcodes = dict(zip(bdf[0], bdf[1] + "+" + bdf[2]))

        # check barcodes sample names: swap every bad character for "_"
        backup = self.barcodes
        self.barcodes = {}
        for key, value in backup.items():
            key = "".join(
                ["_" if char in BADCHARS else char for char in key])
            self.barcodes[key] = value
Exemple #5
0
def _loci_to_arr(loci, taxdict, mindict):
    """
    return a frequency array from a loci file for all loci with taxa from
    taxdict and min coverage from mindict.

    Parameters
    ----------
    loci : sequence of str
        One string per locus, newline separated, each line holding a sample
        name and a sequence separated by whitespace. The last newline-split
        element of every locus is discarded.
    taxdict : dict
        Maps the keys 'p1'..'p4' (4-taxon test) or 'p1'..'p5' (5-taxon test)
        to lists of sample names. The last key in sorted order is used as
        the outgroup.
    mindict : int, dict or other
        Minimum number of samples that must be present per taxon for a locus
        to be kept. An int applies to every taxon, a dict is looked up per
        taxon, anything else means 1 per taxon.

    Returns
    -------
    (arr, keep) : arr is the float array of the kept loci after
    ``masknulls``; keep is the boolean mask over all input loci.

    Raises
    ------
    IPyradError
        If a key of taxdict is not one of 'p1'..'p5'.
    """

    ## get max length of loci (first line of each locus)
    nloci = len(loci)
    maxlen = np.max(np.array([len(locus.split("\n")[0]) for locus in loci]))

    ## mask array to remove loci without cov, and the freq array:
    ## six rows for 5 taxa b/c one for each p3, and for the fused p3 ancestor
    keep = np.zeros(nloci, dtype=np.bool_)
    nrows = 6 if len(taxdict) == 5 else 4
    arr = np.zeros((nloci, nrows, maxlen), dtype=np.float64)

    ## if not mindict, make one that requires 1 in each taxon
    if isinstance(mindict, int):
        mindict = {i: mindict for i in taxdict}
    elif isinstance(mindict, dict):
        mindict = {i: mindict[i] for i in taxdict}
    else:
        mindict = {i: 1 for i in taxdict}

    ## raise error if names are not 'p[int]'
    allowed_names = ['p1', 'p2', 'p3', 'p4', 'p5']
    if any([i not in allowed_names for i in taxdict]):
        raise IPyradError(\
            "keys in taxdict must be named 'p1' through 'p4' or 'p5'")

    ## parse key names
    keys = sorted([i for i in taxdict.keys() if i[0] == 'p'])
    outg = keys[-1]

    ## grab seqs just for the good guys
    for loc in range(nloci):

        ## parse the locus
        lines = loci[loc].split("\n")[:-1]
        names = [i.split()[0] for i in lines]
        seqs = np.array([list(i.split()[1]) for i in lines])

        ## check that names cover the taxdict (still need to check by site)
        covs = [sum([j in names for j in taxdict[tax]]) >= mindict[tax] \
                for tax in taxdict]

        ## skip loci that lack coverage in any taxon
        if not all(covs):
            continue
        keep[loc] = True

        ## get the refseq
        refidx = np.where([i in taxdict[outg] for i in names])[0]
        refseq = seqs[refidx].view(np.uint8)
        ancestral = np.array([reftrick(refseq, GETCONS2)[:, 0]])

        ## freq of ref in outgroup
        iseq = _reffreq2(ancestral, refseq, GETCONS2)
        arr[loc, -1, :iseq.shape[1]] = iseq

        ## enter 4-taxon freqs
        if len(taxdict) == 4:
            for tidx, key in enumerate(keys[:-1]):

                ## get idx of names in test tax
                nidx = np.where([i in taxdict[key] for i in names])[0]
                sidx = seqs[nidx].view(np.uint8)

                ## get freq of sidx and fill it in
                iseq = _reffreq2(ancestral, sidx, GETCONS2)
                arr[loc, tidx, :iseq.shape[1]] = iseq

        else:

            ## enter p5; and fill it in
            ## NOTE(review): this repeats the outgroup fill just above with
            ## the same arguments; kept because _reffreq2 is not visible here.
            iseq = _reffreq2(ancestral, refseq, GETCONS2)
            arr[loc, -1, :iseq.shape[1]] = iseq

            ## enter p1 and p2
            for tidx, key in enumerate(('p1', 'p2')):
                nidx = np.where([i in taxdict[key] for i in names])[0]
                sidx = seqs[nidx].view(np.uint8)
                iseq = _reffreq2(ancestral, sidx, GETCONS2)
                arr[loc, tidx, :iseq.shape[1]] = iseq

            ## enter p3 with p4 masked, and p4 with p3 masked
            nidx = np.where([i in taxdict['p3'] for i in names])[0]
            nidy = np.where([i in taxdict['p4'] for i in names])[0]
            sidx = seqs[nidx].view(np.uint8)
            sidy = seqs[nidy].view(np.uint8)
            xseq = _reffreq2(ancestral, sidx, GETCONS2)
            yseq = _reffreq2(ancestral, sidy, GETCONS2)

            ## both masks are taken before either array is zeroed
            mask3 = xseq != 0
            mask4 = yseq != 0
            xseq[mask4] = 0
            yseq[mask3] = 0
            arr[loc, 2, :xseq.shape[1]] = xseq
            arr[loc, 3, :yseq.shape[1]] = yseq

            ## enter p34
            nidx = nidx.tolist() + nidy.tolist()
            sidx = seqs[nidx].view(np.uint8)
            iseq = _reffreq2(ancestral, sidx, GETCONS2)
            arr[loc, 4, :iseq.shape[1]] = iseq

    ## size-down array to the number of loci that have taxa for the test
    arr = arr[keep, :, :]

    ## size-down sites (done by masknulls)
    arr = masknulls(arr)

    return arr, keep
Exemple #6
0
    def __init__(
        self,
        data,
        impute_method=None,
        imap=None,
        minmap=None,
        mincov=0.1,
        quiet=False,
        topcov=0.9,
        niters=5,
        #ncomponents=None,
        ld_block_size=0,
    ):
        """
        Load SNP data, filter it, record per-sample missingness and impute.

        Parameters
        ----------
        data : str
            Path to the input file. A path ending in .vcf or .vcf.gz is first
            converted with vcf_to_hdf5 and the resulting database is used.
        impute_method : any
            Stored on the instance. Imputation runs unless this is False (and
            only when the filtered matrix holds missing values). When it is
            an int the SNP extracter is run quietly.
        imap : dict
            Population name -> list of sample names. If empty, a single
            population '1' holding all loaded names is created after
            filtering.
        minmap : dict
            Population name -> min coverage. Defaults to 1 for every imap key.
        mincov : int or float
            Passed to SNPsExtracter.
        quiet : bool
            Suppress printed messages.
        topcov, niters :
            Stored on the instance; not used inside this constructor.
        ld_block_size : int
            Passed to vcf_to_hdf5 for VCF input; 0 is replaced by 20000 there.

        Raises
        ------
        IPyradError
            If sklearn or toyplot has not been imported.
        """

        # only check import at init
        if not sys.modules.get("sklearn"):
            raise IPyradError(_MISSING_SKLEARN)
        if not sys.modules.get("toyplot"):
            raise IPyradError(_MISSING_TOYPLOT)

        # init attributes
        self.quiet = quiet
        self.data = os.path.realpath(os.path.expanduser(data))

        # data attributes
        self.impute_method = impute_method
        self.mincov = mincov
        self.imap = (imap if imap else {})
        self.minmap = (minmap if minmap else {i: 1 for i in self.imap})
        self.topcov = topcov
        self.niters = niters
        self.ld_block_size = ld_block_size

        # where the resulting data are stored.
        self.pcaxes = "No results, you must first call .run()"
        self.variances = "No results, you must first call .run()"

        # to be filled
        self.snps = np.array([])
        self.snpsmap = np.array([])
        self.nmissing = 0

        # Works now. ld_block_size will have no effect on RAD data
        if self.data.endswith((".vcf", ".vcf.gz")):
            if not ld_block_size:
                self.ld_block_size = 20000
                if not self.quiet:
                    print(_IMPORT_VCF_INFO.format(self.ld_block_size))

            converter = vcf_to_hdf5(
                # file name of the path as the user passed it, up to ".vcf"
                name=os.path.basename(data).split(".vcf")[0],
                data=self.data,
                ld_block_size=self.ld_block_size,
                quiet=True,
            )
            # run the converter
            converter.run()
            # Set data to the new hdf5 file
            self.data = converter.database

        # load .snps and .snpsmap from HDF5
        first = (True if isinstance(self.impute_method, int) else quiet)
        ext = SNPsExtracter(
            self.data,
            self.imap,
            self.minmap,
            self.mincov,
            quiet=first,
        )

        # run snp extracter to parse data files
        ext.parse_genos_from_hdf5()
        self.snps = ext.snps
        self.snpsmap = ext.snpsmap
        self.names = ext.names
        self._mvals = ext._mvals

        # make imap for imputing if not used in filtering.
        if not self.imap:
            self.imap = {'1': self.names}
            self.minmap = {'1': 0.5}

        # record missing data per sample: fraction of sites coded 9 in each
        # row of the SNP matrix, rounded to 2 decimals. The frame is built in
        # one step; filling a placeholder frame through chained assignment
        # (frame.column[name] = value) is not reliable across pandas versions.
        miss = np.sum(self.snps == 9, axis=1) / self.snps.shape[1]
        self.missing = pd.DataFrame(
            {
                "missing": np.round(miss, 2),
            },
            index=self.names,
        )

        # impute missing data
        if (self.impute_method is not False) and self._mvals:
            self._impute_data()
Exemple #7
0
    def parse_genos_from_hdf5(self, return_as_characters=False):
        """
        Parse genotype calls from hdf5 snps file and store snpsmap
        for subsampling.

        Reads the "snps", "snpsmap" and "genos" datasets from ``self.data``,
        applies the indel, bi-allelic, mincov, minmap, invariant and minor
        allele frequency filters, reports each count through ``self._print``
        and stores the results on ``self.snps``, ``self.snpsmap``,
        ``self._mvals`` and ``self._msites``.

        Parameters
        ----------
        return_as_characters : bool
            If True, ``self.snps`` is overwritten at the end with the
            filtered raw "snps" dataset viewed as "S1" instead of the summed
            integer genotype calls.

        Raises
        ------
        IPyradError
            If ``self.mincov`` or a minmap value is neither int nor float,
            if indexing the samples of an imap population fails, or if no
            SNP passes filtering.
        """
        # load arrays from hdf5; could be more efficient...
        with h5py.File(self.data, 'r') as io5:

            # snps is used to filter multi-allel and indel containing
            snps = io5["snps"][:]

            # snpsmap is used to subsample per locus
            # [:, 1] is overwritten as a range over non-filtered sites below.
            snpsmap = io5["snpsmap"][:, :2]

            # genos are the actual calls we want, after filtering
            genos = io5["genos"][:]  # .sum(axis=2).T

            # report pre-filter
            self._print("Samples: {}".format(len(self.names)))
            self._print("Sites before filtering: {}".format(snps.shape[1]))

            # filter all sites containing an indel (45 is ord("-")) in
            # selected samples
            mask0 = np.any(snps[self.sidxs, :] == 45, axis=0)
            self._print("Filtered (indels): {}".format(mask0.sum()))

            # filter all sites w/ multi-allelic in the selected samples
            mask1 = np.sum(genos[:, self.sidxs] == 2, axis=2).sum(axis=1).astype(bool)
            mask1 += np.sum(genos[:, self.sidxs] == 3, axis=2).sum(axis=1).astype(bool)
            self._print("Filtered (bi-allel): {}".format(mask1.sum()))

            # convert genos (e.g., 1/1) to genos sums (e.g., 2)
            diplo = genos.sum(axis=2).T

            # convert any summed missing (18) to 9
            diplo[diplo == 18] = 9

            # filter based on mincov of subsamples (default = 0.0)
            cov = np.sum(diplo[self.sidxs, :] != 9, axis=0)

            if isinstance(self.mincov, int):
                mask2 = cov < self.mincov

            elif isinstance(self.mincov, float):
                mask2 = cov < self.mincov * len(self.sidxs)

            else:
                raise IPyradError("mincov should be an int or float.")
            self._print("Filtered (mincov): {}".format(mask2.sum()))

            # filter based on per-population min coverages
            mask3 = np.zeros(snps.shape[1], dtype=np.bool_)
            if self.imap:
                for key, val in self.imap.items():
                    mincov = self.minmap[key]
                    pidxs = np.array(sorted(self.dbnames.index(i) for i in val))
                    try:
                        subarr = diplo[pidxs, :]
                    except IndexError:
                        raise IPyradError("imap is empty: {} - {}".format(key, val))
                    counts = np.sum(subarr != 9, axis=0)
                    if isinstance(mincov, float):
                        mask3 += (counts / subarr.shape[0]) < mincov
                    elif isinstance(mincov, int):
                        mask3 += counts < mincov
                    else:
                        raise IPyradError("minmap dictionary malformed.")
            self._print("Filtered (minmap): {}".format(mask3.sum()))

            # filter based on whether subsample still is variable at site
            # after masking missing data values, while also collapsing data
            # back into diploid genotype calls where haplo-missing is masked.
            marr = np.ma.array(
                data=diplo[self.sidxs, :],
                mask=diplo[self.sidxs, :] == 9,
            )

            # round to the most common call (0, 1, 2)
            common = marr.mean(axis=0).round().astype(int)

            # mask if any non-mask data is not the common genotype and invert
            mask4 = np.invert(np.any(marr != common, axis=0).data)
            self._print("Filtered (subsample invariant): {}".format(mask4.sum()))

            # maf filter: applies to all gene copies (haploids)
            if isinstance(self.maf, int):
                freqs = marr.sum(axis=0)
            else:
                # two gene copies per non-missing sample
                freqs = marr.sum(axis=0) / (2 * np.sum(~marr.mask, axis=0))
                freqs[freqs > 0.5] = 1 - (freqs[freqs > 0.5])
            mask5 = freqs < self.maf

            # only report filters unique to mask5
            masks = mask0 + mask1 + mask2 + mask3 + mask4
            drop = mask5[~masks].sum()
            self._print("Filtered (minor allele frequency): {}".format(drop))

            # apply the filters (adding boolean arrays is an elementwise OR)
            summask = masks + mask5
            allmask = np.invert(summask)

            # store filtered int genotype calls (the output used by most tools)
            self.snps = diplo[self.sidxs, :][:, allmask]

            # total report. summask already includes mask5, so its sum is the
            # number of sites removed by any filter; adding mask5 again would
            # count the maf-filtered sites twice.
            totalsnps = summask.sum()
            self._print("Filtered (combined): {}".format(totalsnps))

            # bail out if ALL snps were filtered
            if self.snps.size == 0:
                raise IPyradError("No SNPs passed filtering.")

            # report state
            self._print(
                "Sites after filtering: {}".format(self.snps.shape[1])
            )
            self._mvals = np.sum(self.snps == 9)
            self._msites = np.any(self.snps == 9, axis=0).sum()
            self._print(
                "Sites containing missing values: {} ({:.2f}%)"
                .format(
                    self._msites,
                    100 * self._msites / self.snps.shape[1],
                )
            )
            self._print(
                "Missing values in SNP matrix: {} ({:.2f}%)"
                .format(
                    self._mvals,
                    100 * self._mvals / self.snps.size,
                )
            )

            # apply mask to snpsmap to map snps, renumbering the second
            # column over the retained sites
            snpsmap = snpsmap[allmask, :]
            snpsmap[:, 1] = np.arange(snpsmap.shape[0])
            self.snpsmap = snpsmap

            # final report
            self._print(
                "SNPs (total): {}\nSNPs (unlinked): {}".format(
                    self.snps.shape[1],
                    np.unique(self.snpsmap[:, 0]).size
                )
            )

            # overwrite geno calls with str data.
            # store the filtered SNP calls (actual A,C,T or G)
            if return_as_characters:
                self.snps = snps[self.sidxs, :][:, allmask].view("S1")
Exemple #8
0
    def plot(
            self,
            show_test_labels=True,
            use_edge_lengths=True,
            collapse_outgroup=False,
            pct_tree_x=0.5,
            pct_tree_y=0.2,
            subset_tests=None,
            #toytree_kwargs=None,
            *args,
            **kwargs):
        """ 
        Draw a multi-panel figure with tree, tests, and results 

        Plotting is disabled in this version: the method prints a notice and
        returns None before reading any argument. The code after the early
        return is unreachable and is kept only for when plotting is
        re-enabled.

        Parameters:
        -----------
        show_test_labels: bool
        ...

        use_edge_lengths: bool
            Passed to toytree.tree() in the (currently unreachable) body.

        collapse_outgroup: bool
        ...

        pct_tree_x: float
        ...

        pct_tree_y: float
        ...

        subset_tests: list
            Indices into self.tests / self.results_boots selecting the tests
            to show in the (currently unreachable) body.

        *args, **kwargs:
            Accepted and currently ignored. NOTE(review): the earlier
            docstring listed `height` and `width` here; presumably they are
            meant to arrive through kwargs -- confirm when re-enabling.

        """
        print("Plotting baba results is not implemented in v.0.9.")
        return

        ## ---- unreachable below this line until plotting is re-enabled ----

        ## check for attributes
        if not self.newick:
            raise IPyradError("baba plot requires a newick treefile")
        if not self.tests:
            raise IPyradError("baba plot must have a .tests attribute")

        ## ensure tests is a list
        if isinstance(self.tests, dict):
            self.tests = [self.tests]

        ## re-decompose the tree
        ttree = toytree.tree(
            self.newick,
            orient='down',
            use_edge_lengths=use_edge_lengths,
        )

        ## subset test to show fewer
        if subset_tests is not None:
            #tests = self.tests[subset_tests]
            tests = [self.tests[i] for i in subset_tests]
            boots = self.results_boots[subset_tests]
        else:
            tests = self.tests
            boots = self.results_boots
Exemple #9
0
    def parse_genos_from_hdf5(self):
        """
        Parse genotype calls from hdf5 snps file and store snpsmap
        for subsampling.

        Reads the "snps", "snpsmap" and "genos" datasets from ``self.data``,
        applies the indel, bi-allelic, mincov and minmap filters, reports
        each count through ``self._print`` and stores the results on
        ``self.snps``, ``self.snpsmap``, ``self._mvals`` and
        ``self._msites``.

        Raises
        ------
        IPyradError
            If ``self.mincov`` or a minmap value is neither int nor float,
            or if no SNP passes filtering.
        """
        # load arrays from hdf5; could be more efficient...
        # The slices copy the datasets into memory, so the file is only
        # needed for these three reads; the context manager closes it even
        # when one of the errors below is raised.
        with h5py.File(self.data, 'r') as io5:

            # snps is used to filter multi-allel and indel containing
            snps = io5["snps"][:]

            # snpsmap is used to subsample per locus
            snpsmap = io5["snpsmap"][:, :2]

            # genos are the actual calls we want, after filtering
            genos = io5["genos"][:]  # .sum(axis=2).T

        # report pre-filter
        self._print("Samples: {}".format(len(self.names)))
        self._print("Sites before filtering: {}".format(snps.shape[1]))

        # filter all sites containing an indel (45 is ord("-")) in selected
        # samples
        mask0 = np.any(snps[self.sidxs, :] == 45, axis=0)
        self._print("Filtered (indels): {}".format(mask0.sum()))

        # filter all sites w/ multi-allelic
        mask1 = np.sum(genos[:, self.sidxs] == 2,
                       axis=2).sum(axis=1).astype(bool)
        mask1 += np.sum(genos[:, self.sidxs] == 3,
                        axis=2).sum(axis=1).astype(bool)
        self._print("Filtered (bi-allel): {}".format(mask1.sum()))

        # convert genos (e.g., 1/1) to genos sums (e.g., 2)
        genos = genos.sum(axis=2).T
        # convert any summed missing (18) to 9
        genos[genos == 18] = 9

        # filter based on total sample min coverage
        cov = np.sum(genos[self.sidxs, :] != 9, axis=0)
        if isinstance(self.mincov, int):
            mask2 = cov < self.mincov
        elif isinstance(self.mincov, float):
            mask2 = cov < self.mincov * len(self.sidxs)
        else:
            raise IPyradError("mincov should be an int or float.")
        self._print("Filtered (mincov): {}".format(mask2.sum()))

        # filter based on per-population min coverages (no-op without imap)
        mask3 = np.zeros(snps.shape[1], dtype=np.bool_)
        if self.imap:
            for key, val in self.imap.items():
                mincov = self.minmap[key]
                pidxs = np.array(sorted(self.dbnames.index(i) for i in val))
                subarr = genos[pidxs, :]
                counts = np.sum(subarr != 9, axis=0)
                if isinstance(mincov, float):
                    mask3 += (counts / subarr.shape[0]) < mincov
                elif isinstance(mincov, int):
                    mask3 += counts < mincov
                else:
                    raise IPyradError("minmap dictionary malformed.")
        self._print("Filtered (minmap): {}".format(mask3.sum()))

        # apply mask to snps (adding boolean arrays is an elementwise OR)
        summask = mask0 + mask1 + mask2 + mask3
        allmask = np.invert(summask)
        self._print("Filtered (combined): {}".format(summask.sum()))
        self.snps = genos[self.sidxs, :][:, allmask]

        # bail out if ALL snps were filtered
        if self.snps.size == 0:
            raise IPyradError("No SNPs passed filtering.")

        # report state
        self._print("Sites after filtering: {}".format(self.snps.shape[1]))
        self._mvals = np.sum(self.snps == 9)
        self._msites = np.any(self.snps == 9, axis=0).sum()
        self._print("Sites containing missing values: {} ({:.2f}%)".format(
            self._msites,
            100 * self._msites / self.snps.shape[1],
        ))
        self._print("Missing values in SNP matrix: {} ({:.2f}%)".format(
            self._mvals,
            100 * self._mvals / self.snps.size,
        ))

        # apply mask to snpsmap to map snps, renumbering the second column
        # over the retained sites
        snpsmap = snpsmap[allmask, :]
        snpsmap[:, 1] = np.arange(snpsmap.shape[0])
        self.snpsmap = snpsmap
Exemple #10
0
    def __init__(
        self,
        pcatool,
        ax0=0,
        ax1=1,
        cycle=8,
        colors=None,
        opacity=None,
        shapes=None,
        size=12,
        legend=True,
        label='',
        outfile='',
        imap=None,
        width=400, 
        height=300,
        axes=None,
        **kwargs):
        """
        Build the drawing from a PCA tool's results and optionally save it.

        Argument documentation lives on the .draw() function. Extra keyword
        arguments are accepted and not used here.
        """
        # source object and the data pulled from it
        self.pcatool = pcatool
        self.datas = self.pcatool.pcaxes
        self.names = self.pcatool.names
        # a user imap takes precedence over the one stored on the tool
        self.imap = self.pcatool.imap if not imap else imap
        self.ax0, self.ax1 = ax0, ax1
        self.axes = axes

        # styling and output options supplied by the user
        self.cycle = cycle
        self.colors = colors
        self.shapes = shapes
        self.opacity = opacity
        self.size = size
        self.legend = legend
        self.label = label
        self.outfile = outfile
        self.height = height
        self.width = width

        # values derived from the data by the two helpers below
        self.nreplicates = None
        self.variance = None
        self._parse_replicate_runs()
        self._regress_replicates()

        # canvas and axes: created here, or the user supplied axes are used
        self.canvas = None
        self.axes = axes
        self._setup_canvas_and_axes()

        # marker styles, then the markers themselves
        self.rstyles = {}
        self.pstyles = {}
        self._get_marker_styles()
        self._assign_styles_to_marks()
        self._draw_markers()

        # legend and file output both need a canvas
        has_canvas = self.canvas is not None

        if self.legend and has_canvas:
            self._add_legend()

        # Write to pdf/svg
        if self.outfile and has_canvas:
            if not self.outfile.endswith((".pdf", ".svg")):
                raise IPyradError("outfile only supports pdf/svg.")
            if self.outfile.endswith(".pdf"):
                toyplot.pdf.render(self.canvas, self.outfile)
            else:
                toyplot.svg.render(self.canvas, self.outfile)
Exemple #11
0
    def parse_genos_from_hdf5(self):
        """
        Load genotype calls from the hdf5 snps database and filter sites.

        Reads the "snps", "snpsmap" and "genos" datasets from the file at
        self.data, builds five boolean site masks (indels, multi-allelic,
        mincov, per-population minmap, invariant in the subsample), and
        stores the surviving sites in:

        - self.snps: summed genotype calls for the samples in self.sidxs,
          with 9 marking missing data.
        - self.snpsmap: the first two snpsmap columns for retained sites,
          with column 1 renumbered 0..nsites-1.
        - self._mvals / self._msites: counts of missing cells / sites.

        Progress is reported through self._print.

        Raises
        ------
        IPyradError
            If mincov or a minmap entry is neither int nor float, if an
            imap population cannot be indexed, or if no site survives.
        """
        # pull the three datasets fully into memory
        with h5py.File(self.data, 'r') as io5:
            snps = io5["snps"][:]
            # only the first two columns are needed; column 1 is rewritten
            # below as a running index over retained sites
            snpsmap = io5["snpsmap"][:, :2]
            genos = io5["genos"][:]

        nsites = snps.shape[1]
        self._print("Samples: {}".format(len(self.names)))
        self._print("Sites before filtering: {}".format(nsites))

        # sites where any selected sample has value 45 (byte value of '-')
        has_indel = (snps[self.sidxs, :] == 45).any(axis=0)
        self._print("Filtered (indels): {}".format(has_indel.sum()))

        # sites where any selected sample carries allele code 2 or 3
        sub_genos = genos[:, self.sidxs]
        multi_allelic = (
            np.sum(sub_genos == 2, axis=2).sum(axis=1).astype(bool)
            | np.sum(sub_genos == 3, axis=2).sum(axis=1).astype(bool)
        )
        self._print("Filtered (bi-allel): {}".format(multi_allelic.sum()))

        # collapse the two allele calls into one summed call per sample,
        # samples on rows; a summed missing call (18) is recoded as 9
        diplo = genos.sum(axis=2).T
        diplo[diplo == 18] = 9
        sub_diplo = diplo[self.sidxs, :]

        # minimum coverage across the subsample: int = absolute count,
        # float = proportion of the selected samples
        coverage = np.sum(sub_diplo != 9, axis=0)
        if isinstance(self.mincov, int):
            threshold = self.mincov
        elif isinstance(self.mincov, float):
            threshold = self.mincov * len(self.sidxs)
        else:
            raise IPyradError("mincov should be an int or float.")
        low_cov = coverage < threshold
        self._print("Filtered (mincov): {}".format(low_cov.sum()))

        # per-population minimum coverage; nothing is masked without an imap
        low_pop = np.zeros(nsites, dtype=np.bool_)
        if self.imap:
            for pop, members in self.imap.items():
                popmin = self.minmap[pop]
                rows = np.array(sorted(
                    self.dbnames.index(member) for member in members))
                try:
                    pop_calls = diplo[rows, :]
                except IndexError:
                    raise IPyradError("imap is empty: {} - {}".format(
                        pop, members))
                present = np.sum(pop_calls != 9, axis=0)
                if isinstance(popmin, float):
                    low_pop |= (present / pop_calls.shape[0]) < popmin
                elif isinstance(popmin, int):
                    low_pop |= present < popmin
                else:
                    raise IPyradError("minmap dictionary malformed.")
        self._print("Filtered (minmap): {}".format(low_pop.sum()))

        # a site is invariant when every non-missing call equals the
        # rounded mean call (0, 1, 2) of the subsample
        masked = np.ma.array(data=sub_diplo, mask=(sub_diplo == 9))
        common = masked.mean(axis=0).round().astype(int)
        invariant = np.invert(np.any(masked != common, axis=0).data)
        self._print("Filtered (subsample invariant): {}".format(
            invariant.sum()))

        # union of all filters, and its complement (the sites to keep)
        dropped = has_indel | multi_allelic | low_cov | low_pop | invariant
        keep = np.invert(dropped)
        self._print("Filtered (combined): {}".format(dropped.sum()))
        self.snps = sub_diplo[:, keep]

        # bail out if ALL snps were filtered
        if self.snps.size == 0:
            raise IPyradError("No SNPs passed filtering.")

        # report the amount of missing data that remains
        nkept = self.snps.shape[1]
        self._print("Sites after filtering: {}".format(nkept))
        self._mvals = np.sum(self.snps == 9)
        self._msites = np.any(self.snps == 9, axis=0).sum()
        self._print("Sites containing missing values: {} ({:.2f}%)".format(
            self._msites,
            100 * self._msites / nkept,
        ))
        self._print("Missing values in SNP matrix: {} ({:.2f}%)".format(
            self._mvals,
            100 * self._mvals / self.snps.size,
        ))

        # keep the locus map in step with the retained sites
        kept_map = snpsmap[keep, :]
        kept_map[:, 1] = range(kept_map.shape[0])
        self.snpsmap = kept_map
Exemple #12
0
    def run(
        self, 
        ipyclient=None, 
        quiet=False,
        force=False,
        block=False,
        ):
        """
        Submits mrbayes job to run. If no ipyclient object is provided then 
        the function will block until the mb run is finished. If an ipyclient
        is provided then the job is sent to a remote engine and an asynchronous 
        result object is stored in self.rasync, which can be queried or 
        awaited until it finishes.

        Parameters
        -----------
        ipyclient:
            Optional parallel client. When given, the job is submitted 
            through ipyclient.load_balanced_view(); otherwise it is run 
            locally and its output is stored in self.stdout.
        quiet: 
            suppress print statements. This only affects printing; it does
            not change whether the call blocks.
        force:
            overwrite existing results files with this job name. 
        block:
            will block progress in notebook until job finishes, even if job
            is running on a remote ipyclient. If the remote job failed its
            error is re-raised here.

        Raises
        ------
        IPyradError
            If the input data file does not exist.
        """

        # check for input data file
        if not os.path.exists(self.data):
            raise IPyradError("data file not found {}".format(self.data))

        # with force, remove result files left over from an earlier run
        # (assumes iterating self.trees yields (name, path) pairs)
        if force:
            for key, oldfile in self.trees:
                if os.path.exists(oldfile):
                    os.remove(oldfile)

        # refuse to run over existing results
        if os.path.exists(self.trees.pstat):
            print("Error Files Exist: set a new name or use Force flag.\n{}"
                  .format(self.trees.pstat))
            return 

        # rewrite nexus file in case params have been updated
        self._write_nexus_file()

        # local run: this call returns only once the job is done
        if not ipyclient:
            self.stdout = _call_mb([self.binary, self.nexus])
            if not quiet:
                print("job {} finished successfully".format(self.name))
            return

        # remote run: submit through a load-balanced view
        lbview = ipyclient.load_balanced_view()
        self.rasync = lbview.apply(
            _call_mb, [self.binary, self.nexus])

        # non-blocking: hand control back right after submission
        if not block:
            if not quiet:
                print("job {} submitted to cluster".format(self.name))
            return

        # blocking: wait regardless of `quiet`, so that quiet=True only
        # silences messages instead of also skipping the wait
        if not quiet:
            print("job {} running".format(self.name))
        ipyclient.wait()
        if self.rasync.successful():
            if not quiet:
                print(
                    "job {} finished successfully"
                    .format(self.name))
        else:
            # re-raise the remote error in the caller
            self.rasync.get()
Exemple #13
0
def tree2tests(newick, constraint_dict, constraint_exact):
    """
    Returns a list of four-taxon tests that can be drawn from a rooted tree.

    Each test is a dict with keys 'p4', 'p3', 'p2', 'p1' mapping to the
    leaf names under the node chosen for that position. Candidate nodes
    are found by nested traversals (outgroup first, then p3, p2, p1) and
    kept only when test_constraint() accepts them for their position.
    Tests are de-duplicated on the tuple of the four node names.

    Parameters
    ----------
    newick : tree input passed to toytree.tree(); must yield a rooted tree.
    constraint_dict : dict or None
        Optional constraints keyed by "p1".."p4"; missing keys default to
        an empty list.
    constraint_exact : bool or list of bool
        A single bool is expanded to all four positions; a list must have
        one entry per constraint key.

    Raises
    ------
    IPyradError
        If the tree is not rooted.
    Exception
        If constraint_exact is a list of the wrong length.
    """
    tree = toytree.tree(newick)
    if not tree.is_rooted():
        raise IPyradError(
            "Input tree must be rooted to use generate_tests_from_tree()")

    # start with no constraint on any position, then overlay user values
    cdict = OrderedDict((pos, []) for pos in ["p1", "p2", "p3", "p4"])
    if constraint_dict:
        cdict.update(constraint_dict)

    # a single bool applies to every position
    if isinstance(constraint_exact, bool):
        constraint_exact = [constraint_exact] * 4

    if isinstance(constraint_exact, list):
        if len(constraint_exact) != len(cdict):
            raise Exception(
                "constraint_exact must be bool or list of bools of length N")

    seen = set()
    found = []

    # walk root to tips; each child of a node is tried as the outgroup
    # side. All traversals are levelorder.
    for topnode in tree.treenode.traverse():
        for oparent in topnode.children:
            for onode in oparent.traverse("levelorder"):
                if not test_constraint(
                        onode, cdict, "p4", constraint_exact[3]):
                    continue

                # the (p1, p2, p3) clade is searched in the sister of oparent
                p123parent = oparent.get_sisters()[0]
                for p123node in p123parent.traverse("levelorder"):
                    for p3parent in p123node.children:
                        for p3node in p3parent.traverse("levelorder"):
                            if not test_constraint(
                                    p3node, cdict, "p3", constraint_exact[2]):
                                continue

                            # the (p1, p2) clade is the sister of p3parent
                            p12parent = p3parent.get_sisters()[0]
                            for p12node in p12parent.traverse("levelorder"):
                                for p2parent in p12node.children:
                                    for p2node in p2parent.traverse(
                                            "levelorder"):
                                        if not test_constraint(
                                                p2node, cdict, "p2",
                                                constraint_exact[1]):
                                            continue

                                        # p1 is searched in the sister
                                        # of p2parent
                                        p1parent = p2parent.get_sisters()[0]
                                        for p1node in p1parent.traverse(
                                                "levelorder"):
                                            if not test_constraint(
                                                    p1node, cdict, "p1",
                                                    constraint_exact[0]):
                                                continue

                                            key = (
                                                onode.name,
                                                p3node.name,
                                                p2node.name,
                                                p1node.name,
                                            )
                                            test = {
                                                'p4': onode.get_leaf_names(),
                                                'p3': p3node.get_leaf_names(),
                                                'p2': p2node.get_leaf_names(),
                                                'p1': p1node.get_leaf_names(),
                                            }
                                            if key not in seen:
                                                found.append(test)
                                                seen.add(key)
    return found
Exemple #14
0
def collate_files(data, sname, tmp1s, tmp2s):
    """ 
    Collate temp fastq files in tmp-dir into 1 gzipped sample.

    The R1 temp files are concatenated and compressed into
    "<sname>_R1_.fastq.gz" inside data.dirs.fastqs. When 'pair' is in
    data.params.datatype the R2 temp files are collated the same way into
    "<sname>_R2_.fastq.gz". Temp files are deleted after each successful
    collation.

    Parameters
    ----------
    data : object providing .dirs.fastqs and .params.datatype.
    sname : sample name used in the output file names.
    tmp1s, tmp2s : iterables of temp file paths for R1 and R2 reads.

    Raises
    ------
    IPyradError
        If the compression command exits with a non-zero status.
    """
    ## prefer pigz when it is found on the PATH, else fall back to gzip
    found = sps.Popen(['which', 'pigz'], stderr=sps.PIPE,
                      stdout=sps.PIPE).communicate()
    compress = ["pigz"] if found[0].strip() else ["gzip"]

    ## read 1
    out1 = os.path.join(data.dirs.fastqs, "{}_R1_.fastq.gz".format(sname))
    _collate_reads(tmp1s, out1, compress, "R1")

    ## read 2, only for paired datatypes
    if 'pair' in data.params.datatype:
        out2 = os.path.join(data.dirs.fastqs, "{}_R2_.fastq.gz".format(sname))
        _collate_reads(tmp2s, out2, compress, "R2")


def _collate_reads(tmpfiles, outpath, compress, label):
    """
    Pipe `cat tmpfiles` through the `compress` command into outpath, then
    delete the temp files. `label` ("R1"/"R2") is used in error messages.
    """
    ## NOTE(review): the child process is handed this handle as its stdout,
    ## so it is not written through the GzipFile object itself. Kept as in
    ## the original; confirm whether a plain binary file handle was intended.
    out = io.BufferedWriter(gzip.open(outpath, 'w'))
    try:
        proc1 = sps.Popen(['cat'] + list(tmpfiles),
                          stderr=sps.PIPE, stdout=sps.PIPE)
        try:
            proc2 = sps.Popen(compress,
                              stdin=proc1.stdout,
                              stderr=sps.PIPE,
                              stdout=out)
            ## stdout goes to the file, so only stderr is captured here
            _, errmsg = proc2.communicate()
        finally:
            ## always release our ends of the pipes and reap the reader
            proc1.stdout.close()
            proc1.stderr.close()
            proc1.wait()

        if proc2.returncode:
            ## format the message; the original passed the '%s' template and
            ## the error as two separate exception args so it was never filled
            raise IPyradError(
                "error in collate_files {} {}".format(
                    label, errmsg.decode(errors="replace")))
    finally:
        ## close the output handle on the error path too
        out.close()

    ## then cleanup (reached only if compression succeeded)
    for tmpfile in tmpfiles:
        os.remove(tmpfile)
Exemple #15
0
    def draw(self,
             ax0=0,
             ax1=1,
             cycle=8,
             colors=None,
             shapes=None,
             size=10,
             legend=True,
             width=400,
             height=300,
             **kwargs):
        """
        Draw a scatterplot for data along two PC axes. 

        With a single replicate each sample is drawn as one marker. With
        several replicates every replicate is drawn as a faint cloud and the
        per-sample centroid across replicates is drawn on top.

        Parameters
        ----------
        ax0, ax1 : int
            Column indices of the axes drawn on x and y.
        cycle : int
            Number of populations styled before the default colors repeat;
            each default shape is used for `cycle` consecutive populations.
            Overridden by len(colors) when enough colors are supplied.
        colors, shapes : sequence or None
            One entry per population in self.imap order. If fewer entries
            than populations are given they are reused cyclically.
        size : marker size.
        legend : bool
            Add a legend when there is more than one population.
        width, height : canvas dimensions.
        **kwargs : accepted but not used by this method.

        Returns
        -------
        (canvas, axes, mark) where mark is the last scatterplot added.

        Raises
        ------
        IPyradError
            If results are missing because run() was not called first.
        AssertionError
            If ax0 or ax1 is not an available axis.

        Note
        ----
        Replicate axes in self.pcaxes may be sign-flipped in place so that
        all replicates share the orientation of replicate 0.
        """
        try:
            # check for replicates in the data
            datas = self.pcaxes
            nreplicates = len(datas)
            variance = np.array(list(self.variances.values())).mean(axis=0)
        except AttributeError:
            raise IPyradError(
                "You must first call run() before calling draw().")

        # check that requested axes exist
        assert max(ax0, ax1) < self.pcaxes[0].shape[1], (
            "data set only has {} axes.".format(self.pcaxes[0].shape[1]))

        # test reversions of replicate axes (clumpp like) so that all plot
        # in the same orientation as replicate 0.
        model = LinearRegression()
        for i in range(1, len(datas)):
            for ax in [ax0, ax1]:
                orig = datas[0][:, ax].reshape(-1, 1)
                new = datas[i][:, ax].reshape(-1, 1)
                swap = (datas[i][:, ax] * -1).reshape(-1, 1)

                # get the regression slope for both orientations
                model.fit(orig, new)
                c0 = model.coef_[0][0]
                model.fit(orig, swap)
                c1 = model.coef_[0][0]

                # if the swapped fit has the larger slope, keep the flip
                if c1 > c0:
                    datas[i][:, ax] = datas[i][:, ax] * -1

        # make reverse imap dictionary (sample name -> population)
        irev = {}
        for pop, vals in self.imap.items():
            for val in vals:
                irev[val] = pop

        # the max number of pops until color cycle repeats
        # If the passed in number of colors is big enough to cover
        # the number of pops then set cycle to len(colors)
        # If colors == None this first `if` falls through (lazy evaluation)
        if colors and len(colors) >= len(self.imap):
            cycle = len(colors)
        else:
            cycle = min(cycle, len(self.imap))

        # get color list repeating in cycles of cycle
        if not colors:
            colors = itertools.cycle(
                toyplot.color.broadcast(
                    toyplot.color.brewer.map("Spectral"),
                    shape=cycle,
                ))
        else:
            # reuse user colors cyclically so that supplying fewer colors
            # than populations does not raise StopIteration below
            colors = itertools.cycle(colors)

        # get shapes list: each of the 7 shapes is repeated `cycle` times
        if not shapes:
            shapes = itertools.cycle(
                np.concatenate([
                    np.tile("o", cycle),
                    np.tile("s", cycle),
                    np.tile("^", cycle),
                    np.tile("d", cycle),
                    np.tile("v", cycle),
                    np.tile("<", cycle),
                    np.tile("x", cycle),
                ]))
        else:
            # same wrap-around behavior as for colors
            shapes = itertools.cycle(shapes)

        # assign styles to populations and to legend markers (no replicates)
        pstyles = {}
        rstyles = {}
        for pop in self.imap:

            color = next(colors)
            shape = next(shapes)

            # solid, outlined marker for points/centroids and the legend
            pstyles[pop] = toyplot.marker.create(
                size=size,
                shape=shape,
                mstyle={
                    "fill": toyplot.color.to_css(color),
                    "stroke": "#262626",
                    "stroke-width": 1.0,
                    "fill-opacity": 0.75,
                },
            )
            # faint marker for replicate clouds; opacity is split across
            # the replicates
            rstyles[pop] = toyplot.marker.create(
                size=size,
                shape=shape,
                mstyle={
                    "fill": toyplot.color.to_css(color),
                    "stroke": "none",
                    "fill-opacity": 0.9 / nreplicates,
                },
            )

        # assign styled markers to data points
        pmarks = []
        rmarks = []
        for name in self.names:
            pop = irev[name]
            pmarks.append(pstyles[pop])
            rmarks.append(rstyles[pop])

        # get axis labels for PCA or TSNE plot (a negative variance value
        # selects the TSNE labels)
        if variance[ax0] >= 0.0:
            xlab = "PC{} ({:.1f}%) explained".format(ax0, variance[ax0] * 100)
            ylab = "PC{} ({:.1f}%) explained".format(ax1, variance[ax1] * 100)
        else:
            xlab = "TSNE component 1"
            ylab = "TSNE component 2"

        # plot points with colors x population
        canvas = toyplot.Canvas(width, height)
        axes = canvas.cartesian(
            grid=(1, 5, 0, 1, 0, 4),
            xlabel=xlab,
            ylabel=ylab,
        )

        # if not replicates then just plot the points
        if nreplicates < 2:
            mark = axes.scatterplot(
                datas[0][:, ax0],
                datas[0][:, ax1],
                marker=pmarks,
                title=self.names,
            )

        # replicates show clouds plus centroids
        else:
            # add the replicates cloud points
            for i in range(nreplicates):
                mark = axes.scatterplot(
                    datas[i][:, ax0],
                    datas[i][:, ax1],
                    marker=rmarks,
                )

            # compute centroids: every sample is its own class, so the
            # centroid is that sample's mean position across replicates
            Xarr = np.concatenate([
                np.array([datas[i][:, ax0], datas[i][:, ax1]]).T
                for i in range(nreplicates)
            ])
            yarr = np.tile(np.arange(len(self.names)), nreplicates)
            clf = NearestCentroid()
            clf.fit(Xarr, yarr)

            # draw centroids
            mark = axes.scatterplot(
                clf.centroids_[:, 0],
                clf.centroids_[:, 1],
                title=self.names,
                marker=pmarks,
            )

        # add a legend
        if legend:
            if len(self.imap) > 1:
                marks = list(pstyles.items())
                canvas.legend(marks,
                              corner=("right", 35, 100,
                                      min(250,
                                          len(pstyles) * 25)))
        return canvas, axes, mark
Exemple #16
0
def zcat_make_temps(data, ftup, num, tmpdir, optim, start):
    """ 
    Call bash commands 'cat' (or 'gunzip -c') and 'split' to break large
    fastq files into chunk files of int(optim) * 4 lines each.

    Parameters
    ----------
    data : object providing .params.datatype.
    ftup : pair of file paths (R1, R2). Whether the files are gzipped is
        decided from the R1 name only.
    num : value embedded in the chunk file prefixes.
    tmpdir : directory where chunk files are written.
    optim : number of 4-line records per chunk.
    start : not used by this function.

    Returns
    -------
    list of (chunk1, chunk2) tuples of sorted chunk paths. For unpaired
    data every chunk2 entry is 0.

    Raises
    ------
    IPyradError
        If a 'split' command exits with a non-zero status.
    """
    # read it, is it gzipped?
    catcmd = ["gunzip", "-c"] if ftup[0].endswith(".gz") else ["cat"]

    # make name prefix
    chunk1 = os.path.join(tmpdir, "chunk1_{}_".format(num))
    chunk2 = os.path.join(tmpdir, "chunk2_{}_".format(str(num)))

    # command to split stdin and write to a prefix (4-char suffixes)
    splitcmd = ["split", "-a", "4", "-l", str(int(optim) * 4), "-"]

    # run 'cat|gunzip -c rawfile | split ...' for read1s
    _stream_into_split(catcmd + [ftup[0]], splitcmd + [chunk1])
    chunks1 = sorted(glob.glob(chunk1 + "*"))

    # repeat for paired reads
    if "pair" in data.params.datatype:
        _stream_into_split(catcmd + [ftup[1]], splitcmd + [chunk2])
        chunks2 = sorted(glob.glob(chunk2 + "*"))
    else:
        # placeholders so every R1 chunk still has a partner entry
        chunks2 = [0] * len(chunks1)

    # ensure r1==r2
    assert len(chunks1) == len(chunks2), "Different number of R1 and R2 files"

    return list(zip(chunks1, chunks2))


def _stream_into_split(readcmd, splitcmd):
    """
    Run `readcmd | splitcmd`, raising IPyradError if splitcmd fails.
    """
    # NOTE(review): the reader's stderr is merged into the stream that is
    # piped to split (kept as in the original) -- confirm this is intended.
    reader = sps.Popen(readcmd,
                       stderr=sps.STDOUT,
                       stdout=sps.PIPE,
                       universal_newlines=True)
    try:
        splitter = sps.Popen(splitcmd,
                             stderr=sps.STDOUT,
                             stdout=sps.PIPE,
                             stdin=reader.stdout,
                             universal_newlines=True)
        res = splitter.communicate()[0]
    finally:
        # release our copy of the pipe and reap the reader so no file
        # descriptor or zombie process is left behind
        reader.stdout.close()
        reader.wait()

    if splitter.returncode:
        raise IPyradError("error in zcat_make_temps:\n{}".format(res))
Exemple #17
0
 def pcs(self, rep=0):
     """
     Return the axes stored for replicate `rep` as a pandas DataFrame
     with one row per entry of self.names.

     Raises IPyradError if the DataFrame cannot be built (ValueError),
     which is reported as run() not having been called yet.
     """
     try:
         return pd.DataFrame(self.pcaxes[rep], index=self.names)
     except ValueError:
         raise IPyradError("You must call run() before accessing the pcs.")
Exemple #18
0
def batch(
    baba,
    ipyclient=None,
):
    """
    Distribute D-statistic jobs to the parallel client and collect results.

    One `dstat` job is submitted per test in baba.tests. At most
    len(ipyclient) jobs are queued at a time; a new job is submitted each
    time a finished one is collected.

    Parameters
    ----------
    baba : object providing .data (a loci file path, or a generator for
        simulated data), .tests (a dict or a list of dicts), .params.mincov
        and .params.nboots.
    ipyclient : parallel client. Required: a falsy value raises IPyradError.

    Returns
    -------
    (resarr, bootsarr)
        When the last submitted test has 4 entries: a DataFrame with one
        row per test plus the array of bootstrap values.
    (resarr, None)
        Otherwise: a multi-index DataFrame built from the per-test result
        panels.

    Raises
    ------
    IPyradError
        If no ipyclient is given or if any remote job fails.
    """

    ## parse args
    handle = baba.data
    taxdicts = baba.tests
    mindicts = baba.params.mincov
    nboots = baba.params.nboots

    ## if ms generator make into reusable list
    sims = 0
    if isinstance(handle, types.GeneratorType):
        handle = list(handle)
        sims = 1
    else:
        ## expand locifile path to full path
        handle = os.path.realpath(handle)

    ## parse taxdicts into names and lists if it a dictionary
    #if isinstance(taxdicts, dict):
    #    names, taxdicts = taxdicts.keys(), taxdicts.values()
    #else:
    #    names = []
    ## names stays empty here, so results are indexed by integer below;
    ## a single test dict is wrapped into a one-element list
    names = []
    if isinstance(taxdicts, dict):
        taxdicts = [taxdicts]

    ## arrays to hold results: resarr is (ntests, 7) summary values and
    ## bootsarr is (ntests, nboots); paneldict holds results that do not
    ## have exactly one row
    tot = len(taxdicts)
    resarr = np.zeros((tot, 7), dtype=np.float64)
    bootsarr = np.zeros((tot, nboots), dtype=np.float64)
    paneldict = {}

    ## TODO: Setup a wrapper to find and cleanup ipyclient
    ## a parallel client is mandatory: there is no serial fallback here
    if not ipyclient:
        # ipyclient = ip.core.parallel.get_client(**self._ipcluster)
        raise IPyradError("you must enter an ipyparallel.Client() object")
    else:
        lbview = ipyclient.load_balanced_view()

    ## submit jobs to run on the cluster queue
    ## asyncs maps job index -> async result; idx is the next job index
    start = time.time()
    asyncs = {}
    idx = 0

    ## prepare data before sending to engines
    ## if it's a str (locifile) then parse it here just once.
    if isinstance(handle, str):
        with open(handle, 'r') as infile:
            loci = infile.read().strip().split("|\n")
    if isinstance(handle, list):
        pass  #sims()

    ## iterate over tests; the same mindicts object is paired with every test
    itests = iter(taxdicts)
    imdict = itertools.cycle([mindicts])

    ## prime the queue with up to len(ipyclient) jobs
    #for test, mindict in zip(taxdicts, itertools.cycle([mindicts])):
    for i in range(len(ipyclient)):

        ## next entries unless fewer than len ipyclient, skip
        try:
            test = next(itests)
            mindict = next(imdict)
        except StopIteration:
            continue

        ## if it's sim data then convert to an array
        ## (no job is submitted on this path; only a message is printed)
        if sims:
            loci = _msp_to_arr(handle, test)
            args = (loci, test, mindict, nboots)
            print("not yet implemented")
            #asyncs[idx] = lbview.apply_async(dstat, *args)
        else:
            args = [loci, test, mindict, nboots]
            asyncs[idx] = lbview.apply(dstat, *args)
        idx += 1

    ## block until finished, print progress if requested.
    finished = 0
    try:
        while 1:
            ## snapshot the ready jobs so asyncs can be modified in the loop
            keys = [i for (i, j) in asyncs.items() if j.ready()]
            ## check for failures
            for job in keys:
                if not asyncs[job].successful():
                    raise IPyradError(\
                        " error: {}: {}".format(job, asyncs[job].exception()))
                ## enter results for successful jobs
                else:
                    _res, _bot = asyncs[job].result()

                    ## store D4 results (a single-row result)
                    if _res.shape[0] == 1:
                        resarr[job] = _res.T.values[:, 0]
                        bootsarr[job] = _bot

                    ## or store D5 results
                    else:
                        paneldict[job] = _res.T

                    ## remove old job
                    del asyncs[job]
                    finished += 1

                    ## submit next job if there is one.
                    try:
                        test = next(itests)
                        mindict = next(imdict)
                        if sims:
                            loci = _msp_to_arr(handle, test)
                            args = (loci, test, mindict, nboots)
                            print("not yet implemented")
                            #asyncs[idx] = lbview.apply_async(dstat, *args)
                        else:
                            args = [loci, test, mindict, nboots]
                            asyncs[idx] = lbview.apply(dstat, *args)
                        idx += 1
                    except StopIteration:
                        pass

            ## update the progress bar and break once no job is outstanding.
            ## NOTE(review): `elap` is computed but never used.
            #fin = idx - len(asyncs)
            elap = datetime.timedelta(seconds=int(time.time() - start))
            printstr = " calculating D-stats"
            progressbar(finished, tot, start, message=printstr)
            time.sleep(0.1)
            if not asyncs:
                print("")
                break

    except KeyboardInterrupt as inst:
        ## try to abort outstanding jobs (best effort) and then re-raise
        try:
            ipyclient.abort()
        except Exception:
            pass
        raise inst

    ## dress up resarr as a Pandas DataFrame if 4-part test
    ## NOTE(review): `test` is whichever test was pulled last from itests;
    ## it is unbound if taxdicts was empty -- confirm callers never pass none.
    if len(test) == 4:
        if not names:
            names = range(len(taxdicts))
        #print("resarr")
        #print(resarr)
        resarr = pd.DataFrame(resarr,
                              index=names,
                              columns=[
                                  "dstat", "bootmean", "bootstd", "Z", "ABBA",
                                  "BABA", "nloci"
                              ])

        ## sort results and bootsarr to match if test names were supplied
        resarr = resarr.sort_index()
        order = [list(resarr.index).index(i) for i in names]
        bootsarr = bootsarr[order]
        return resarr, bootsarr
    else:
        ## order results dfs by job index
        listres = []
        for key in range(len(paneldict)):
            listres.append(paneldict[key])

        ## make into a multi-index dataframe: three rows per test labelled
        ## 'p3', 'p4' and 'shared'
        ntests = len(paneldict)
        multi_index = [
            np.array([[i] * 3 for i in range(ntests)]).flatten(),
            np.array(['p3', 'p4', 'shared'] * ntests),
        ]
        resarr = pd.DataFrame(
            data=pd.concat(listres).values,
            index=multi_index,
            columns=listres[0].columns,
        )
        return resarr, None
Exemple #19
0
    def plot(self,
             show_test_labels=True,
             use_edge_lengths=False,
             collapse_outgroup=False,
             pct_tree_x=0.5,
             pct_tree_y=0.2,
             subset_tests=None,
             prune_tree_to_tests=False,
             *args,
             **kwargs):
        """ 
        Draw a multi-panel figure with tree, tests, and results.

        The tree is rebuilt from self.newick and handed, together with the
        tests and bootstrap results, to baba_panel_plot().

        Parameters:
        -----------
        show_test_labels: bool
            Forwarded to baba_panel_plot.

        use_edge_lengths: bool
            Forwarded to baba_panel_plot.

        collapse_outgroup: bool
            Forwarded to baba_panel_plot.

        pct_tree_x: float
            Forwarded to baba_panel_plot.

        pct_tree_y: float
            Forwarded to baba_panel_plot.

        subset_tests: list
            Indices selecting which tests (and matching rows of
            self.results_boots) are shown. None shows all tests.

        prune_tree_to_tests: bool
            Drop tree tips that do not occur in the first row of
            self.taxon_table before plotting.

        *args, **kwargs:
            Forwarded unchanged to baba_panel_plot (e.g., height, width).

        Returns:
        --------
        (canvas, axes, panel) as returned by baba_panel_plot.

        Raises:
        -------
        IPyradError if self.newick or self.tests is empty.
        """

        ## both a tree and at least one test are required
        if not self.newick:
            raise IPyradError("baba plot requires a newick treefile")
        if not self.tests:
            raise IPyradError("baba plot must have a .tests attribute")

        ## a single test dict is stored back as a one-element list
        if isinstance(self.tests, dict):
            self.tests = [self.tests]

        ## rebuild the tree object from the newick string
        ttree = toytree.tree(self.newick)

        ## select everything, or only the requested subset of tests
        if subset_tests is None:
            tests = self.tests
            boots = self.results_boots
        else:
            tests = [self.tests[idx] for idx in subset_tests]
            boots = self.results_boots[subset_tests]

        ## optionally drop tips that are not part of the tests
        if prune_tree_to_tests:
            keep = set(itertools.chain(*self.taxon_table.values[0]))
            unused = [tip for tip in ttree.get_tip_labels() if tip not in keep]
            ttree = ttree.drop_tips(unused)
            ttree.tree.ladderize()

        ## hand everything to the panel plotting function
        canvas, axes, panel = baba_panel_plot(
            *args,
            ttree=ttree,
            tests=tests,
            boots=boots,
            show_test_labels=show_test_labels,
            use_edge_lengths=use_edge_lengths,
            collapse_outgroup=collapse_outgroup,
            pct_tree_x=pct_tree_x,
            pct_tree_y=pct_tree_y,
            **kwargs)
        return canvas, axes, panel