def store_positions(data, client_name):
    """Persist the loci one client reported for a single chromosome.

    *data* is a pickled mapping carrying ``'CHROM'`` and ``'POS'``. The
    positions are written to the module-level ``store`` under
    ``<chrom>/positions`` as ``uint32`` and the locus count is logged.
    """
    # NOTE(review): pickle.loads can execute arbitrary code for crafted input;
    # only acceptable when *data* comes from a trusted peer.
    payload = pickle.loads(data)
    chrom, positions = payload['CHROM'], payload['POS']
    write_or_replace(store, f"{chrom}/positions", positions, np.uint32)
    logging.info(f"{client_name} has {len(positions)} loci in chromosome {chrom}.")
    def update_pval(self, message):
        """Accumulate one reply's log-likelihood contribution for a model.

        ``self.likelihood[model]`` holds ``[running_sum, replies_still_expected]``
        and ``self.nconnections`` replies are expected per model. When the last
        one arrives, the sum ``ell`` and ``chi2sf(-2 * ell, 1)`` are written to
        ``meta/<model>/ell`` and ``meta/<model>/pval`` in the module-level
        ``store``, and the next pending model is dispatched. A message whose
        ``"Estimated"`` field is ``"Small"`` only triggers that dispatch step.
        """
        def _share_new_ll():
            """Send the next model from ``self.chroms`` out, or report completion."""
            # Previously a bare ``print`` to stdout; route through logging like
            # the rest of the module.
            logging.debug(f"Remaining models: {self.chroms}")
            if self.chroms:
                self.send_coef(self.chroms.pop(), {})
                return
            logging.info("P-value computation is Done!")
            logging.info(f"Computing P-values took roughly {time.time() - self.time:.1f} seconds.")

        # NOTE(review): pickle.loads executes arbitrary code for crafted input;
        # only acceptable when every peer is trusted.
        message = pickle.loads(message)
        model = message["Estimated"]
        if model == "Small":
            _share_new_ll()
            return
        val = message["estimate"]
        if model not in self.likelihood:
            # First reply for this model: it already counts toward the total.
            self.likelihood[model] = [val, self.nconnections - 1]
            return
        running_sum, outstanding = self.likelihood[model]
        if outstanding != 1:
            self.likelihood[model] = [running_sum + val, outstanding - 1]
            return
        # Last expected reply: finalise the statistic for this model.
        ell = running_sum + val
        pval = chi2sf(-2*ell, 1)
        write_or_replace(store, f"meta/{model}/ell", ell)
        write_or_replace(store, f"meta/{model}/pval", pval)
        del self.likelihood[model]
        _share_new_ll()
 def collect_likelihoods(self, data):
     """Sum the likelihood values the clients report for one model's current point.

     Each message is a pickled mapping carrying ``"estimated"`` (the model key)
     and ``"v"`` (this client's value). ``self.linesearch_iter[model]`` counts
     the replies still outstanding in the round. On the round's last reply,
     either ``self.newton_test_new_point(model)`` is called (model not
     finished), or, for a finished model, the summed value is written to
     ``meta/<model>/newton_ell`` with ``chi2sf(-2 * ell, 1)`` as
     ``meta/<model>/newton_pval`` and the model's per-model state is dropped.
     Once every model is finished, ``manhattan_plot`` is called and the store
     is closed.

     NOTE(review): other handlers read the key ``"Estimated"``; confirm the
     sender really uses lowercase ``"estimated"`` for this message.
     NOTE(review): pickle.loads on network data is only safe with trusted peers.
     """
     payload = pickle.loads(data)
     model = payload["estimated"]
     if model not in self.scratch_likelihoods:
         # First reply of a round: start the sum. The decrement at the end
         # leaves ``nconnections - 1`` replies outstanding.
         self.scratch_likelihoods[model] = payload['v']
         self.linesearch_iter[model] = self.nconnections
     else:
         self.scratch_likelihoods[model] += payload['v']
         if self.linesearch_iter[model] == 1:  # this is the round's last reply
             if not self.finished[model]:
                 self.newton_test_new_point(model)
             else:
                 ell = self.scratch_likelihoods[model]
                 write_or_replace(store, f"meta/{model}/newton_ell", ell)
                 write_or_replace(store, f"meta/{model}/newton_pval",
                                  chi2sf(-2*ell, 1))
                 del self.scratch_likelihoods[model]
                 for per_model in (self.Hess, self.Gradients, self.Diags, self.fchanges):
                     del per_model[model]
                 if all(self.finished.values()):
                     manhattan_plot(storePath, "manhattan_plot.png")
                     logging.info("P-value computation is Done!")
                     logging.info(f"Computing P-values took roughly {time.time() - self.time:.1f} seconds.")
                     store.close()
     # Runs after newton_test_new_point on purpose, so any counter it resets
     # is decremented exactly as before.
     self.linesearch_iter[model] -= 1
Example #4
0
 def replace_dataset(tokeep, dset_name, return_deleted=False):
     """Rewrite ``group[dset_name]`` keeping only the entries selected by *tokeep*.

     ``group`` is not a parameter. It is resolved from an enclosing or module
     scope that is not visible here (TODO confirm). The dataset is rewritten
     through ``write_or_replace``. Returns the dropped entries when
     *return_deleted* is true, otherwise ``None``.
     """
     # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads the whole
     # dataset on every h5py version.
     vals = group[dset_name][()]
     write_or_replace(group, dset_name, vals[tokeep])
     if return_deleted:
         # Only materialise the dropped entries when the caller wants them.
         return vals[np.logical_not(tokeep)]
     return None
def store_filtered(message, client_config):
    """Narrow each chromosome's ``PCA_mask`` with a server-supplied sub-mask.

    *message* is a pickled ``{chrom: val}`` mapping. For every chromosome, the
    entries of ``PCA_mask`` that are currently ``True`` are overwritten, in
    order, with the values in ``val``, and the result is written back to the
    client's plink HDF5 store.
    """
    pfile = client_config["plinkfile"]
    # NOTE(review): pickle.loads on network data is unsafe unless the sender
    # is trusted.
    msg = pickle.loads(message)
    with h5py.File(shared.get_plink_store(pfile), 'a') as store:
        for chrom, val in msg.items():
            group = store[str(chrom)]
            # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` is portable.
            mask = group["PCA_mask"][()]
            mask[mask] = val
            write_or_replace(group, 'PCA_mask', val=mask)
    def update_estimate(self, z_hat, model, data):
        """Aggregate one reply's coefficient contribution for *model*.

        ``self.estimates[model]`` holds ``[running_sum, replies_still_expected]``.
        When the last of ``self.nconnections`` replies of a round arrives, the
        sum is averaged into ``beta`` and ``self.iters[model]`` advances.

        - Below ``self.max_iters``, the new ``beta`` is sent back with an
          ``"estimate"`` request for the next round.
        - At ``self.max_iters``, it is written to ``meta/<model>/coef`` and the
          model is marked finished. Then either the next chromosome from
          ``self.chroms`` is activated or, when none remain, the clients are
          set to ``"ASSO_DONE"`` and p-value computation starts.

        Args:
            z_hat: this reply's contribution; summed over all replies, then
                divided by ``self.nconnections``.
            model: ``"Small"`` or another model key (presumably a chromosome
                name, given ``self.chroms`` / ``make_chrom_active`` — TODO confirm).
            data: unused by the live code; only the commented-out
                normalization lines reference it.
        """
        if model in self.estimates:
            # Late replies for a model that already reached max_iters are dropped.
            if self.iters[model] >= self.max_iters:  # this shouldn't happen but it does! WHy?
                logging.info(f"WHYYYYY {model}, {self.iters[model]}")
                return
            prev = self.estimates[model]
            # self.normalization_stats["data"] = np.vstack((self.normalization_stats["data"], data["cov"]))
            if prev[1] == 1:
                # Last expected reply of this round: average all contributions.
                beta = (prev[0] + z_hat)/(self.nconnections)
                self.iters[model] += 1
                # Every 10th iteration, log the change relative to self.beta.
                # NOTE(review): self.beta is one attribute shared by all models,
                # so this may compare against another model's latest estimate.
                if not self.iters[model] % 10:
                    logging.info(f"Finished iteration {self.iters[model]} on {model}")
                    if model == "Small":
                        logging.info(f"{np.sum(np.abs(beta - self.beta))}")
                    else:
                        logging.info(f"{np.linalg.norm(beta[7,:] - self.beta)}")
                # For non-"Small" models only row 7 of beta is kept in self.beta
                # (what row 7 represents is not visible here — TODO confirm).
                if model == "Small":
                    self.beta = beta
                else:
                    self.beta = beta[7, :]

                if self.iters[model] == self.max_iters:  # #TODO this shouldn't happen but why does it?
                    # Final iteration: persist the coefficients and retire the model.
                    write_or_replace(store, f"meta/{model}/coef", beta)
                    if model != "Small":
                        del self.estimates[model]
                    self.active_chroms.remove(model)
                    if model == "Small":
                        # NOTE(review): this sets the shared max_iters to 15 for all
                        # models and stores the bare beta array rather than a
                        # [sum, count] pair. If iters["Small"] < 15 lets a later
                        # "Small" reply past the guard above, prev[0]/prev[1] would
                        # be read as rows of beta.
                        self.max_iters = 15
                        self.estimates["Small"] = beta
                    # chroms = [key for key in store if key != 'meta']
                    # if len(self.finished) == len(chroms) + 1:
                    self.finished[model] = True
                    if not self.chroms:
                        self.set_clients_state("ASSO_DONE")
                        logging.info(f"We are done with association!")
                        self.initialize_pval_computation()
                    else:
                        chrom = self.chroms.pop()
                        self.active_chroms.append(chrom)
                        self.make_chrom_active(chrom)
                    return
                else:
                    # Start the next round. All nconnections replies go through the
                    # ``prev`` branch above, so the count starts at nconnections
                    # (round one starts at nconnections - 1, see below).
                    self.estimates[model] = [np.zeros_like(beta), self.nconnections]
                    msg = {"Estimated": model, "VALS": beta, "Iter": self.iters[model]}
                    self.send_request(msg, "estimate")
            else:
                # Not the last reply yet: accumulate and count down.
                self.estimates[model] = [prev[0] + z_hat, prev[1] - 1]

        else:  # Not in dictionary yet
            # First reply of the first round. It already counts, hence
            # nconnections - 1 replies are still expected.
            # self.normalization_stats["data"] = data["cov"]

            self.estimates[model] = [z_hat, self.nconnections - 1]
            self.iters[model] = 1
Example #7
0
def _wait_for_active_task(client_name, task_name, poll_interval=.1):
    """Poll Celery until *task_name* is no longer active on this client's worker."""
    inspector = current_app.control.inspect()
    while True:
        # Query once per poll. The old code called ``active()`` twice, so a
        # ``None`` on the second call made it index into ``None``.
        active = inspector.active()
        if active is None:
            return
        worker_tasks = active[f'celery@{client_name}']
        if not any(task['name'] == task_name for task in worker_tasks):
            return
        logger.info(f'Waiting on {task_name} to finish.')
        time.sleep(poll_interval)


def init_stats(msg_dict, client_config, env):
    """Store per-chromosome QC statistics received from the server, then report.

    First waits until no ``tasks.init_store`` task is active on this client's
    Celery worker. Then, for every ``chrom -> message`` entry of *msg_dict*,
    writes the statistics present in ``message`` into that chromosome's group
    of the local plink HDF5 store:

    * ``"MISS"`` -> ``not_missing_per_snp`` (stored as ``1 - MISS``)
    * ``"AF"``   -> ``MAF``
    * ``"HWE"``  -> ``hwe``
    * ``"VAR"``  -> ``VAR``

    Finally reports ``Finished with init stats.`` to the server.
    """
    logger.debug(f'init_stats received chromosomes: {list(msg_dict.keys())}')
    client_name = client_config['name']
    # Wait on previous tasks to finish.
    _wait_for_active_task(client_name, 'tasks.init_store')
    pfile = client_config['plinkfile']
    with h5py.File(shared.get_plink_store(pfile), 'a') as store:
        for chrom, message in msg_dict.items():
            logger.info(f'Computing statistics for Chrom: {chrom}.')
            chrom_group = store[chrom]
            if "MISS" in message:
                # Missingness is stored as its complement.
                write_or_replace(chrom_group, "not_missing_per_snp",
                                 val=1 - np.array(message["MISS"]))
            for key, dset_name in (("AF", "MAF"), ("HWE", "hwe"), ("VAR", "VAR")):
                if key in message:
                    write_or_replace(chrom_group, dset_name,
                                     val=np.array(message[key]))
            # Logged per chromosome. It previously ran once after the loop and
            # raised NameError when msg_dict was empty.
            logger.info(f'Finished initializing QC statistics for chrom {chrom}.')

    status = 'Finished with init stats.'
    networking.respond_to_server(
        f'api/clients/{client_name}/report?status={status}', 'POST', env=env)
def pca_projection(data, client_config):
    """Project this client's standardized genotypes onto the shared PCA basis.

    *data* is a pickled mapping with ``"ISIG"`` (presumably inverse singular
    values), ``"V"`` and ``"CHROMS"``. The steps are:

    1. Every SNP kept by a chromosome's ``PCA_mask`` is standardized as
       ``(g - 2 * MAF) / sqrt(VAR)``, with NaN replaced by 0.
    2. The resulting ``(n_individuals, n_kept_snps)`` matrix ``X`` is
       projected as ``X @ V.T @ diag(ISIG)``.
    3. Signs are fixed with ``svd_flip``.
    4. ``pca_sigma``, ``pca_v.T`` and ``pca_u`` are written to the ``pca``
       group of the local store.
    """
    message = pickle.loads(data)
    inv_sigma = message["ISIG"]
    v = message["V"]
    chroms = message["CHROMS"]
    pfile = shared.get_plink_store(client_config["plinkfile"])
    with h5py.File(pfile, 'a') as store:
        # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` is portable.
        n_snps = 0
        for chrom in chroms:
            n_snps += np.sum(store[f"{chrom}/PCA_mask"][()])
        num_inds = store.attrs["n"]
        arr = np.empty((num_inds, n_snps), dtype=np.float32)
        offset = 0
        for chrom in chroms:
            group = store[str(chrom)]
            tokeep = group["PCA_mask"][()]
            af = group["MAF"][()][tokeep]
            sd = np.sqrt(group["VAR"][()][tokeep])
            positions = group["positions"][()][tokeep]
            for i, position in enumerate(positions):
                val = (group[str(position)][()] - 2 * af[i]) / sd[i]
                val[np.isnan(val)] = 0
                arr[:, offset + i] = val
            # Advance by the number of kept SNPs. The old ``offset += i + 1``
            # was stale (or unbound) when a chromosome kept no SNPs.
            offset += len(positions)
        u = arr.dot(v.T).dot(np.diag(inv_sigma))
        u, v = svd_flip(u, v, u_based_decision=False)
        dset = store.require_group("pca")
        write_or_replace(dset, 'pca_sigma', val=inv_sigma)
        write_or_replace(dset, 'pca_v.T', val=v)
        write_or_replace(dset, 'pca_u', val=u)
    logger.info("Done with projection!")
def _write_chromosome_metadata(group, chrom, positions, rsids, counts,
                               client_config, env):
    """Store one chromosome's positions/rsids/counts and send its positions to the server."""
    write_or_replace(group, 'positions', positions, np.uint32)
    write_or_replace(group, 'rsids', rsids)
    write_or_replace(group, 'counts', counts, np.uint32)
    send_positions_to_server(positions, chrom, client_config, env)


def _write_sample_metadata(store, pfile, sample_list):
    """Write store attributes plus ``meta/Status``, ``meta/id`` and ``meta/regions``."""
    store.attrs['n'] = len(sample_list)
    store.attrs['has_local_AF'] = False
    store.attrs['has_global_AF'] = False
    store.attrs['has_centering'] = False
    store.attrs['has_normalization'] = False
    pheno_file = pfile + ".pheno"
    if os.path.isfile(pheno_file):
        # An explicit phenotype file overrides the PLINK affection status;
        # its third column (index 2) is used.
        affection = np.loadtxt(pheno_file, dtype=int, usecols=2)
    else:
        affection = [sample.affection for sample in sample_list]
    if len(np.unique(affection)) > 2:
        raise ValueError(
            "phenotype is not binary. We only support binary for now")
    write_or_replace(store, 'meta/Status', affection, np.int8)
    write_or_replace(store, 'meta/id',
                     [sample.iid for sample in sample_list], 'S11')
    # Read demographic file
    logger.info(f'Reading demographic file at {pfile}.ind')
    logger.info(f'File exists: {os.path.isfile(pfile + ".ind")}')
    with open(pfile + ".ind", 'r') as dem_f:
        # NOTE(review): if the region is the last tab-separated column, its
        # trailing newline is kept in the stored value. Confirm the layout.
        dem = [(row.split("\t")[2]).encode("UTF8") for row in dem_f]
    write_or_replace(store, 'meta/regions', dem)


def _write_genotypes(store, plink_file, locus_list, n_samples, client_config,
                     env):
    """Write one dataset per locus, grouped by chromosome, stopping at chromosome 23."""
    current_chr = 1
    positions, rsids, all_counts = [], [], []
    current_group = store.require_group(str(current_chr))
    # Passed to process_plink_row for every row (presumably a reusable buffer).
    genotypes = np.zeros(n_samples, dtype=np.float32)
    for locus, row in zip(locus_list, plink_file):
        if locus.chromosome != current_chr:
            if positions:
                _write_chromosome_metadata(current_group, current_chr,
                                           positions, rsids, all_counts,
                                           client_config, env)
            else:
                del store[str(current_chr)]
            # Reset every per-chromosome list. rsids used to leak into the next
            # chromosome, misaligning it with positions.
            positions, rsids, all_counts = [], [], []
            current_chr = locus.chromosome
            if current_chr == 23:
                break
            current_group = store.require_group(str(current_chr))
        pos = str(locus.bp_position)
        counts, geno = process_plink_row(row, genotypes)
        try:
            current_group.create_dataset(pos, data=geno)
        except Exception:
            logger.error(
                f"Cannot write position: chr{locus.chromosome} {pos}")
        rsids.append(locus.name.encode('utf8'))
        positions.append(pos)
        all_counts.append(counts)
    if current_chr != 23:
        # Flush the last chromosome (skipped when the loop stopped at 23).
        if positions:
            _write_chromosome_metadata(current_group, current_chr, positions,
                                       rsids, all_counts, client_config, env)
        else:
            # No loci at all: drop the unused group, as in the loop above.
            del store[str(current_chr)]


def plinkToH5(client_config, env):
    """Gets plink prefix, produces an HDF file with the same prefix.

    Converts the PLINK fileset at ``client_config['plinkfile']`` into the HDF5
    store given by ``shared.get_plink_store`` (the store is recreated):

    * Sample metadata goes under ``meta/``.
    * Each locus becomes a dataset named by its base-pair position inside its
      chromosome's group; conversion stops at the first chromosome-23 locus.
    * Each completed chromosome's positions are sent to the server.

    Raises:
        MemoryError: if the PLINK file cannot be opened for lack of memory.
        ValueError: if the phenotype has more than two distinct values.
    """
    pfile = client_config['plinkfile']
    store_name = shared.get_plink_store(pfile)
    logger.info(f'Opening plinkfile: {pfile}')
    try:
        plink_file = plinkfile.open(pfile)
    except MemoryError as e:
        logger.error('MemoryError!')
        logger.error(e)
        # Re-raise: continuing used to fail later with an unbound plink_file.
        raise
    try:
        if not plink_file.one_locus_per_row():
            logger.error("This script requires that snps are\n"
                         "            rows and samples columns.")
            sys.exit(1)
        sample_list = plink_file.get_samples()
        locus_list = plink_file.get_loci()
        logger.info(f'Opening h5py file:{store_name}')
        with h5py.File(store_name, 'w', libver='latest') as store:
            _write_sample_metadata(store, pfile, sample_list)
            _write_genotypes(store, plink_file, locus_list, len(sample_list),
                             client_config, env)
    finally:
        # Close the PLINK handle on every exit path, including sys.exit above.
        plink_file.close()
    logger.info('Finished writing plink to hdf5.')
Example #10
0
 def newton_iter(self, model):
     """Run one Newton iteration for the unconverged per-SNP models of *model*.

     Uses the gradients, packed Hessians and Hessian diagonals gathered in
     ``self.Gradients``, ``self.Hess`` and ``self.Diags`` to compute a Newton
     step ``dx`` per SNP. A SNP is marked converged in either of two cases:

     * its allele frequency is within ``self.threshold`` of 0 or 1;
     * ``|g.T @ dx|`` is below ``TOL`` (the step is applied first).

     If all SNPs converged or ``self.max_iters`` iterations were run:

     * the coefficients are written to ``meta/<model>/newton_coef``;
     * a final ``"query"`` request is sent;
     * the next chromosome from ``self.chroms`` is activated, or association
       is reported done when none remain.

     Otherwise the candidate points ``x0 + self.t * dx`` of the still
     unconverged SNPs are sent in a ``"query"`` request. The steps and
     predicted changes are kept in ``self.Diags`` / ``self.fchanges``
     (presumably for the line search).
     """
     TOL = 1e-7
     self.t = 1  # step-length factor applied to dx for the candidate points
     hs = self.Hess[model]
     gs = self.Gradients[model]
     ds = self.Diags[model]
     # Per-row predicted change g.T @ dx, stored in self.fchanges at the end.
     dfxs = np.zeros((gs.shape[0], 1))
     # NOTE(review): h5py 3.0 removed Dataset.value; `[()]` is the portable read.
     af = store[f"{model}/allele_freq"].value
     x0 = self.estimates[model]
     if model not in self.converged:
         # First iteration for this model: nothing has converged yet.
         self.iters[model] = 1
         convergence_status = np.zeros((gs.shape[0], 1), dtype=bool)
         self.converged[model] = convergence_status.copy()
         old_not_converged = None
     else:
         # Later iterations only work on SNPs that had not converged. Gradient
         # and Hessian rows are assumed to be in that same order — TODO confirm.
         self.iters[model] += 1
         old_not_converged = np.logical_not(self.converged[model])
         af = af[old_not_converged[:, 0]]
         x0 = x0[old_not_converged[:, 0]]  # boolean indexing copies; written back below
         convergence_status = np.zeros((x0.shape[0], 1), dtype=bool)
     logging.info(f"Now updating coefficient estimates for chr{model} iteration: {self.iters[model]}")
     for i, g in enumerate(gs):
         # SNPs with allele frequency within threshold of 0 or 1 are not fitted.
         if af[i] < self.threshold or 1-af[i] < self.threshold:
             convergence_status[i] = True
             continue
         h = np.diagflat(ds[i])
         # Two Hessians appear to share one square array: even rows use the
         # lower triangle, odd rows the upper one. Mirroring doubles anything on
         # that triangle's diagonal, so the diagonal presumably comes only from
         # ds — TODO confirm.
         if i % 2:
             hessian = np.triu(hs[i//2])
         else:
             hessian = np.tril(hs[i//2])
         hessian += hessian.T
         hessian += h
         # Newton step: least-squares solve of (-hessian) @ dx = g.
         dx = lstsq(-hessian, g[:, np.newaxis], 1e-5, signature='ddd->ddid')[0]
         dfx = dot(g.T, dx)
         if np.abs(dfx) < TOL:
             # Negligible predicted change: take the step and stop for this SNP.
             x0[i, :] += dx
             convergence_status[i] = True
             continue
         # gs is repurposed to hold the candidate point, ds to hold the step.
         gs[i:i+1, :] = x0[i].T + self.t*dx.T
         dfxs[i, :] = dfx
         ds[i] = dx[:, 0]
     # Write updated coefficients back (x0 is a copy when a subset was selected).
     if old_not_converged is not None:
         self.estimates[model][old_not_converged[:, 0]] = x0
     else:
         self.estimates[model] = x0
     if np.prod(convergence_status) or self.iters[model] == self.max_iters:
         # Model finished: persist the coefficients and send a final query for
         # the fitted SNPs, then move on to the next chromosome.
         self.finished[model] = True
         write_or_replace(store, f"meta/{model}/newton_coef", self.estimates[model])
         af = store[f"{model}/allele_freq"].value
         # Mask of SNPs that were fitted (allele frequency not within threshold of 0/1).
         arr = np.logical_not(np.logical_or(af < self.threshold, 1-af < self.threshold))
         est = self.estimates[model][arr]
         msg = {"Estimated": model, "conv": np.expand_dims(arr, axis=1), "x0": est[:, :, 0]}
         self.send_request(msg, "query")
         del self.estimates[model]
         self.active_chroms.remove(model)
         if not self.chroms:
             self.set_clients_state("ASSO_DONE")
             logging.info("We are done with association.")
             logging.info(f"Computing coefficients took roughly {time.time() - self.time:.1f} seconds.")
             self.time = time.time()
             # COMPUTE PVALS
         else:
             chrom = self.chroms.pop()
             self.active_chroms.append(chrom)
             self.make_chrom_active(chrom)
         return
     else:
         # Merge this iteration's results into the model-wide convergence mask.
         if self.iters[model] > 1:
             temp = self.converged[model]
             temp[np.logical_not(temp)] = convergence_status[:, 0]
             self.converged[model] = temp
         else:
             self.converged[model] = convergence_status
         # Only the candidate points of still-unconverged SNPs are sent out.
         gs = gs[np.logical_not(convergence_status)[:, 0], :]
         msg = {"Estimated": model, "x0": gs, "conv": np.logical_not(self.converged[model])}
         self.send_request(msg, "query")
     self.linesearch_convergence[model] = convergence_status.copy()
     del self.Hess[model], self.Gradients[model]
     self.Diags[model] = ds[np.logical_not(convergence_status)[:, 0], :]  # repurposed: now holds the Newton steps dx
     # self.fchanges[model] = dfxs[np.logical_not(convergence_status)]
     self.fchanges[model] = dfxs
Example #11
0
def run_QC(filters, client_config, prefix, remove=True, env="production"):
    """Apply HWE / MAF / missingness filters to every chromosome in the local store.

    Args:
        filters: mapping of ``QCFilterNames`` -> threshold. A SNP passes a
            filter when its statistic is above the threshold. MAF is checked on
            both tails. The missing-per-SNP threshold ``t`` is applied to the
            stored non-missing fraction as ``1 - t``. The mapping is not modified.
        client_config: must provide ``"plinkfile"`` and ``"name"``.
        prefix: ``"QC"`` notifies ``api/tasks/QC/FIN``; any other value
            notifies ``api/tasks/PCA/FIN``. When *remove* is false, the masks
            are stored as ``<prefix>_mask`` / ``<prefix>_positions``.
        remove: if true, failing SNPs are dropped from every per-SNP dataset
            and their genotype datasets are deleted. Otherwise only masks are
            stored (plus ``PCA_passed`` when *prefix* is ``"PCA"``).
        env: forwarded to ``networking.respond_to_server``.
    """
    # Work on a copy and convert the missing-per-SNP threshold exactly once.
    # It used to be flipped inside the per-chromosome loop, so every other
    # chromosome used the wrong threshold and the caller's dict was mutated.
    qc_filters = dict(filters)
    if QCFilterNames.QC_MPS in qc_filters:
        qc_filters[QCFilterNames.QC_MPS] = 1 - float(
            qc_filters[QCFilterNames.QC_MPS])

    def find_what_passes(group, qc_name, dset_name, tokeep, double_sided=False):
        """Return *tokeep* AND-ed with filter *qc_name*, if that filter was requested."""
        if qc_name not in qc_filters:
            return tokeep
        # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` is portable.
        vals = group[dset_name][()]
        thresh = float(qc_filters[qc_name])
        if not double_sided:
            return np.logical_and(tokeep, vals > thresh)
        # Both the value and its complement must clear the threshold.
        return np.logical_and(
            tokeep,
            np.logical_and(
                vals > thresh - Settings.kSmallEpsilon,
                (1.0 - vals) > thresh - Settings.kSmallEpsilon))

    def replace_dataset(group, tokeep, dset_name, return_deleted=False):
        """Keep only the *tokeep* entries of ``group[dset_name]``; optionally return the rest."""
        vals = group[dset_name][()]
        write_or_replace(group, dset_name, vals[tokeep])
        if return_deleted:
            return vals[np.logical_not(tokeep)]
        return None

    pfile = client_config["plinkfile"]
    store_name = shared.get_plink_store(pfile)
    with h5py.File(store_name, 'a') as store:
        for chrom in store.keys():
            group = store[chrom]
            # Skip non-chromosome groups: ``meta`` and, e.g., the ``pca``
            # group written by the projection step.
            if chrom == "meta" or 'positions' not in group:
                continue
            positions = group['positions'][()]
            if "QC_mask" in group:
                tokeep = group["QC_mask"][()]
            else:
                tokeep = np.ones_like(positions, dtype=bool)

            tokeep = find_what_passes(group, QCFilterNames.QC_HWE, "hwe", tokeep)
            tokeep = find_what_passes(group, QCFilterNames.QC_MAF, "MAF", tokeep,
                                      double_sided=True)
            tokeep = find_what_passes(group, QCFilterNames.QC_MPS,
                                      "not_missing_per_snp", tokeep)
            logger.info(
                f"After filtering {chrom}, {np.sum(tokeep)} snps remain")
            if remove:  # Delete what doesn't pass
                for dset_name in ('hwe', 'VAR', 'MAF', 'not_missing_per_snp'):
                    replace_dataset(group, tokeep, dset_name)
                deleted = replace_dataset(group, tokeep, 'positions',
                                          return_deleted=True)
                for snp in deleted:
                    snp = str(snp)
                    if snp in group:
                        del group[snp]
            else:  # Store what has been tagged
                pass_mask = prefix + "_mask"
                pos_mask = prefix + "_positions"
                for name in (pass_mask, pos_mask):
                    if name in group:
                        del group[name]
                write_or_replace(group, pass_mask, val=tokeep, dtype=bool)
                write_or_replace(group, pos_mask, val=positions[tokeep])
                if prefix == "PCA":
                    write_or_replace(group,
                                     "PCA_passed",
                                     val=np.ones(np.sum(tokeep), dtype=bool))
                    if 'non_ld_mask' in group:
                        del group['non_ld_mask']
    client_name = client_config['name']
    endpoint = 'api/tasks/QC/FIN' if prefix == "QC" else 'api/tasks/PCA/FIN'
    networking.respond_to_server(endpoint, "POST", b'', client_name, env)