예제 #1
0
def work_item2(pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,
               xxx_todo_changeme, xxx_todo_changeme1, xxx_todo_changeme2,
               just_testing, do_uncorr, do_gxe2, a2):
    """Fit G, G+E and (G+E)+GxE linear mixed models for a single phenotype.

    :param pheno: phenotype reader holding exactly one phenotype
        (``sid_count == 1``). Individuals whose value is NaN are dropped.
    :param G_kernel: genetic similarity kernel.
    :param spatial_coor: spatial coordinates, or the path of a space-delimited
        file with columns ``id``, ``south_new`` and ``east_new`` (rows with
        missing values are dropped).
    :param spatial_iid: iids matching ``spatial_coor``. Must be None when
        ``spatial_coor`` is a path; the iids are then built as ``(id, id)``.
    :param alpha: passed to ``spatial_similarity`` to build the environment kernel.
    :param alpha_power: passed to ``spatial_similarity`` as ``power``.
    :param xxx_todo_changeme: ``(jackknife_index, jackknife_count, jackknife_seed)``.
        A negative index disables jackknifing.
    :param xxx_todo_changeme1: ``(permute_plus_index, permute_plus_count,
        permute_plus_seed)``. A non-negative index permutes the E kernel.
    :param xxx_todo_changeme2: ``(permute_times_index, permute_times_count,
        permute_times_seed)``. A non-negative index permutes the GxE kernel.
    :param just_testing: if true, skip the LMM searches and use placeholder results.
    :param do_uncorr: if true, fit the G-only model (``h2uncorr``, ``nLLuncorr``).
    :param do_gxe2: if true, fit the G+E vs. GxE mixture (``gxe2``, ``a2_gxe2``).
    :param a2: fixed G/E mixing weight. If None, it is searched for.
    :return: dict of the estimates together with the settings that produced them.

    (The ``xxx_todo_changeme*`` names were produced by 2to3; they are kept so
    keyword-argument callers keep working.)
    """
    # Unpack the jackknife / permutation settings.
    (jackknife_index, jackknife_count, jackknife_seed) = xxx_todo_changeme
    (permute_plus_index, permute_plus_count,
     permute_plus_seed) = xxx_todo_changeme1
    (permute_times_index, permute_times_count,
     permute_times_seed) = xxx_todo_changeme2

    #########################################
    # Load GPS info from filename if that's the way it is given
    #########################################
    if isinstance(spatial_coor, str):
        assert spatial_iid is None, "if spatial_coor is a str, then spatial_iid should be None"
        gps_table = pd.read_csv(spatial_coor, delimiter=" ").dropna()
        spatial_iid = np.array([(v, v) for v in gps_table["id"].values])
        spatial_coor = gps_table[["south_new", "east_new"]].values

    #########################################
    # Remove any missing values from pheno
    #########################################
    assert pheno.sid_count == 1, "Expect only one pheno in work_item"
    pheno = pheno.read()
    pheno = pheno[~np.isnan(pheno.val[:, 0]), :]

    #########################################
    # Environment: Turn spatial info into a KernelData
    #########################################
    spatial_val = spatial_similarity(spatial_coor, alpha, power=alpha_power)
    E_kernel = KernelData(iid=spatial_iid, val=spatial_val)

    #########################################
    # Intersect, apply the jackknife or permutation, and then (because we now know the iids) standardize appropriately
    #########################################
    from pysnptools.util import intersect_apply
    G_kernel, E_kernel, pheno = intersect_apply([G_kernel, E_kernel, pheno])

    if jackknife_index >= 0:
        assert jackknife_count <= G_kernel.iid_count, "expect the number of groups to be less than the number of iids"
        assert jackknife_index < jackknife_count, "expect the jackknife index to be less than the count"
        m_fold = model_selection.KFold(
            n_splits=jackknife_count,
            shuffle=True,
            random_state=jackknife_seed % 4294967295).split(
                list(range(G_kernel.iid_count)))
        # Keep the training part of the jackknife_index-th fold.
        iid_index, _ = _nth(m_fold, jackknife_index)
        pheno = pheno[iid_index, :]
        G_kernel = G_kernel[iid_index]
        E_kernel = E_kernel[iid_index]

    if permute_plus_index >= 0:
        # We shuffle the val, but not the iid, because shuffling both would cancel out.
        # The permute_plus_index is folded into the seed so each permutation differs.
        # A private RandomState gives the same permutation as seeding the global
        # generator and calling np.random.shuffle, without altering global RNG state.
        new_index = np.random.RandomState(
            (permute_plus_seed + permute_plus_index) % 4294967295
        ).permutation(G_kernel.iid_count)
        E_kernel_temp = E_kernel[new_index].read()
        E_kernel = KernelData(
            iid=E_kernel.iid,
            val=E_kernel_temp.val,
            name="permutation {0}".format(permute_plus_index))

    pheno = pheno.read().standardize()  # defaults to Unit standardize
    G_kernel = G_kernel.read().standardize()  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize()  # defaults to DiagKtoN standardize

    #########################################
    # find h2uncorr, the best mixing weight of pure random noise and G_kernel
    #########################################

    if not do_uncorr:
        h2uncorr, nLLuncorr = np.nan, np.nan
    else:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if not just_testing:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        else:
            h2uncorr, nLLuncorr = 0, 0
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(
            h2uncorr, nLLuncorr))

    #########################################
    # Find a2, the best mixing for G_kernel and E_kernel
    #########################################

    if a2 is None:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
        lmm1.setK(K0=G_kernel.val, K1=E_kernel.val, a2=0.5)
        lmm1.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm1.sety(pheno.val[:, 0])
        if not just_testing:
            res1 = lmm1.findA2()
            h2, a2, nLLcorr = res1["h2"], res1["a2"], res1["nLL"]
            # Split the total kernel weight h2 between G (1-a2) and E (a2).
            h2corr = h2 * (1 - a2)
            e2 = h2 * a2
            h2corr_raw = h2
        else:
            h2corr, e2, a2, nLLcorr, h2corr_raw = 0, 0, .5, 0, 0
        logging.info(
            "G plus E mixture: h2corr: {0}, e2: {1}, a2: {2}, nLLcorr: {3} (h2corr_raw:{4})"
            .format(h2corr, e2, a2, nLLcorr, h2corr_raw))
    else:
        h2corr, e2, nLLcorr, h2corr_raw = np.nan, np.nan, np.nan, np.nan

    #########################################
    # Find a2_gxe2, the best mixing for G+E_kernel and the GxE kernel
    #########################################

    if not do_gxe2:
        gxe2, a2_gxe2, nLL_gxe2 = np.nan, np.nan, np.nan
    else:
        # Create the G+E kernel by mixing according to a2
        val = (1 - a2) * G_kernel.val + a2 * E_kernel.val
        GplusE_kernel = KernelData(iid=G_kernel.iid,
                                   val=val,
                                   name="{0} G + {1} E".format(1 - a2, a2))
        # Don't need to standardize GplusE_kernel because it's the weighted combination of standardized kernels

        # Create GxE Kernel and then find the best mixing of it and GplusE
        logging.info("Find best mixing for GxE and GplusE_kernel")

        val = G_kernel.val * E_kernel.val  # recall that Python '*' is just element-wise multiplication
        if permute_times_index >= 0:
            # We shuffle the val, but not the iid, because doing both would cancel out
            new_index = np.random.RandomState(
                (permute_times_seed + permute_times_index) % 4294967295
            ).permutation(G_kernel.iid_count)
            val = pstutil.sub_matrix(val, new_index, new_index)

        GxE_kernel = KernelData(iid=G_kernel.iid, val=val, name="GxE")
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if not just_testing:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            # Only the a2_gxe2 share of the kernel weight belongs to GxE.
            gxe2 *= a2_gxe2
        else:
            gxe2, a2_gxe2, nLL_gxe2 = 0, .5, 0
        logging.info(
            "G+E plus GxE mixture: gxe2: {0}, a2_gxe2: {1}, nLL_gxe2: {2}".
            format(gxe2, a2_gxe2, nLL_gxe2))

    #########################################
    # Return results
    #########################################

    ret = {
        "h2uncorr": h2uncorr,
        "nLLuncorr": nLLuncorr,
        "h2corr": h2corr,
        "h2corr_raw": h2corr_raw,
        "e2": e2,
        "a2": a2,
        "nLLcorr": nLLcorr,
        "gxe2": gxe2,
        "a2_gxe2": a2_gxe2,
        "nLL_gxe2": nLL_gxe2,
        "alpha": alpha,
        "alpha_power": alpha_power,
        "phen": np.array(pheno.sid, dtype='str')[0],
        "jackknife_index": jackknife_index,
        "jackknife_count": jackknife_count,
        "jackknife_seed": jackknife_seed,
        "permute_plus_index": permute_plus_index,
        "permute_plus_count": permute_plus_count,
        "permute_plus_seed": permute_plus_seed,
        "permute_times_index": permute_times_index,
        "permute_times_count": permute_times_count,
        "permute_times_seed": permute_times_seed
    }

    logging.info("run_line: {0}".format(ret))
    return ret
    )  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize(
    )  # defaults to DiagKtoN standardize

    #########################################
    # find h2uncoor, the best mixing weight of pure random noise and G_kernel
    #########################################

    if not do_uncorr:
        h2uncorr, nLLuncorr = np.nan, np.nan
    else:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if not just_testing:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        else:
            h2uncorr, nLLuncorr = 0, 0
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(
            h2uncorr, nLLuncorr))

    #########################################
    # Find a2, the best mixing for G_kernel and E_kernel
    #########################################

    if a2 is None:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
예제 #3
0
    def fit(self, X=None, y=None, K0_train=None, K1_train=None, h2=None, mixing=None):
        """
        Method for training a :class:`FastLMM` predictor. If the examples in X, y, K0_train, K1_train are not the same, they will be reordered and intersected.

        :param X: training covariate information, optional: 
          If you give a string, it should be the file name of a PLINK phenotype-formatted file.
        :type X: a PySnpTools :class:`SnpReader` (such as :class:`Pheno` or :class:`SnpData`) or string.

        :param y: training phenotype:
          If you give a string, it should be the file name of a PLINK phenotype-formatted file.
        :type y: a PySnpTools :class:`SnpReader` (such as :class:`Pheno` or :class:`SnpData`) or string.

        :param K0_train: A similarity matrix or SNPs from which to construct such a similarity matrix.
               Can be any :class:`.SnpReader`. If you give a string, can be the name of a PLINK-formatted Bed file.
               Can be PySnpTools :class:`.KernelReader`. If you give a string it can be the name of a :class:`.KernelNpz` file.
        :type K0_train: :class:`.SnpReader` or a string or :class:`.KernelReader`

        :param K1_train: A second similarity matrix or SNPs from which to construct such a second similarity matrix. (Also, see 'mixing').
               Can be any :class:`.SnpReader`. If you give a string, can be the name of a PLINK-formatted Bed file.
               Can be PySnpTools :class:`.KernelReader`. If you give a string it can be the name of a :class:`.KernelNpz` file.
        :type K1_train: :class:`.SnpReader` or a string or :class:`.KernelReader`

        :param h2: A parameter to LMM learning that tells how much weight to give the K's vs. the identity matrix, optional
                If not given will search for best value.
                If mixing is unspecified, then h2 must also be unspecified.
        :type h2: number

        :param mixing: Weight between 0.0 (inclusive, default) and 1.0 (inclusive) given to K1_train relative to K0_train.
                If you give no mixing number and a K1_train is given, the best weight will be learned.
        :type mixing: number

        :raises AssertionError: if y is not given or has more than one phenotype.

        :rtype: self, the fitted FastLMM predictor
        """
        self.is_fitted = True
        # should this have a cache file like 'single_snp'?
        #!!!later what happens if missing values in pheno_train?
        #!!!later add code so that X, y, etc can be array-like objects without iid information. In that case, make up iid info

        assert y is not None, "y must be given"

        y = _pheno_fixup(y)
        assert y.sid_count == 1, "Expect y to be just one variable"
        X = _pheno_fixup(X, iid_if_none=y.iid)

        K0_train = _kernel_fixup(K0_train, iid_if_none=y.iid, standardizer=self.snp_standardizer)
        K1_train = _kernel_fixup(K1_train, iid_if_none=y.iid, standardizer=self.snp_standardizer)

        # Reorder and intersect so every input covers the same iids in the same order.
        K0_train, K1_train, X, y = intersect_apply([K0_train, K1_train, X, y], intersect_before_standardize=True)  #!!! test this on both K's as None
        from fastlmm.association.single_snp import _set_block_size
        K0_train, K1_train, block_size = _set_block_size(K0_train, K1_train, mixing, self.GB_goal, self.force_full_rank, self.force_low_rank)

        X = X.read()
        # If possible, unit standardize train and test together. If that is not possible, unit standardize only train and later apply
        # the same linear transformation to test. Unit standardization is necessary for FastLMM to work correctly.
        #!!!later is the calculation of the training data's stats done twice???
        X, covar_unit_trained = X.standardize(self.covariate_standardizer, block_size=block_size, return_trained=True)  # This also fills missing with the mean

        # add a column of 1's to cov to increase DOF of model (and accuracy) by allowing a constant offset
        X = SnpData(iid=X.iid,
                    sid=self._new_snp_name(X),
                    val=np.c_[X.val, np.ones((X.iid_count, 1))],
                    name="covariate_train w/ 1's")

        y0 = y.read().val  #!!!later would view_ok=True,order='A' be ok because this code already did a fresh read to look for any missing values

        from fastlmm.association.single_snp import _Mixer  #!!!move _combine_the_best_way to another file (e.g. this one)
        K_train, h2, mixer = _Mixer.combine_the_best_way(K0_train, K1_train, X.val, y0, mixing, h2,
                                                         force_full_rank=self.force_full_rank,
                                                         force_low_rank=self.force_low_rank,
                                                         kernel_standardizer=self.kernel_standardizer,
                                                         block_size=block_size)

        # do final prediction using lmm.py
        lmm = LMM()

        # Special case: The K kernel is defined implicitly with SNP data
        if mixer.do_g:
            assert isinstance(K_train.standardizer, StandardizerIdentity), "Expect Identity standardizer"
            lmm.setG(G0=K_train.snpreader.val)
        else:
            lmm.setK(K0=K_train.val)

        lmm.setX(X.val)
        lmm.sety(y0[:, 0])

        # Find the best h2 and also on covariates (not given from new model)
        if h2 is None:
            res = lmm.findH2()  #!!!why is REML true in the return???
        else:
            res = lmm.nLLeval(h2=h2)

        # We compute sigma2 instead of using res['sigma2'] because res['sigma2'] is only the pure noise.
        # np.sum reduces the (n,1) residual array to a scalar in one vectorized pass;
        # the builtin sum looped in Python and produced a shape-(1,) array, and
        # float() on an ndim>0 array is deprecated in recent NumPy.
        residual = np.dot(X.val, res['beta']).reshape(-1, 1) - y0
        full_sigma2 = float(np.sum(residual ** 2)) / y.iid_count  #!!! this is non REML. Is that right?

        ###### all references to 'fastlmm_model' should be here so that we don't forget any
        self.block_size = block_size
        self.beta = res['beta']
        self.h2 = res['h2']
        self.sigma2 = full_sigma2
        self.U = lmm.U
        self.S = lmm.S
        self.K = lmm.K
        self.G = lmm.G
        self.y = lmm.y
        self.Uy = lmm.Uy
        self.X = lmm.X
        self.UX = lmm.UX
        self.mixer = mixer
        self.covar_unit_trained = covar_unit_trained
        self.K_train_iid = K_train.iid
        self.covar_sid = X.sid
        self.pheno_sid = y.sid
        self.G0_train = K0_train.snpreader if isinstance(K0_train, SnpKernel) else None  #!!!later expensive?
        self.G1_train = K1_train.snpreader if isinstance(K1_train, SnpKernel) else None  #!!!later expensive?
        return self
예제 #4
0
    def fit(self,
            X=None,
            y=None,
            K0_train=None,
            K1_train=None,
            h2raw=None,
            mixing=None,
            count_A1=None):  #!!!is this h2 or h2corr????
        """
        Method for training a :class:`FastLMM` predictor. If the examples in X, y, K0_train, K1_train are not the same, they will be reordered and intersected.

        :param X: training covariate information, optional: 
          If you give a string, it should be the file name of a PLINK phenotype-formatted file.
        :type X: a PySnpTools `SnpReader <http://fastlmm.github.io/PySnpTools/#snpreader-snpreader>`__
          (such as `Pheno <http://fastlmm.github.io/PySnpTools/#snpreader-pheno>`__ or `SnpData <http://fastlmm.github.io/PySnpTools/#snpreader-snpdata>`__) or string.

        :param y: training phenotype:
          If you give a string, it should be the file name of a PLINK phenotype-formatted file.
        :type y: a PySnpTools `SnpReader <http://fastlmm.github.io/PySnpTools/#snpreader-snpreader>`__ 
          (such as `Pheno <http://fastlmm.github.io/PySnpTools/#snpreader-pheno>`__ or `SnpData <http://fastlmm.github.io/PySnpTools/#snpreader-snpdata>`__) or string.

        :param K0_train: A similarity matrix or SNPs from which to construct such a similarity matrix.
               Can be any `SnpReader <http://fastlmm.github.io/PySnpTools/#snpreader-snpreader>`__.
               If you give a string, can be the name of a PLINK-formatted Bed file.
               Can be PySnpTools `KernelReader <http://fastlmm.github.io/PySnpTools/#kernelreader-kernelreader>`__.
               If you give a string it can be the name of a `KernelNpz <http://fastlmm.github.io/PySnpTools/#kernelreader-kernelnpz>`__ file.
        :type K0_train: `SnpReader <http://fastlmm.github.io/PySnpTools/#snpreader-snpreader>`__ or a string or
               `KernelReader <http://fastlmm.github.io/PySnpTools/#kernelreader-kernelreader>`__

        :param K1_train: A second similarity matrix or SNPs from which to construct such a second similarity matrix. (Also, see 'mixing').
               Can be any `SnpReader <http://fastlmm.github.io/PySnpTools/#snpreader-snpreader>`__. If you give a string, can be the name of a PLINK-formatted Bed file.
               Can be PySnpTools `KernelReader <http://fastlmm.github.io/PySnpTools/#kernelreader-kernelreader>`__.
               If you give a string it can be the name of a `KernelNpz <http://fastlmm.github.io/PySnpTools/#kernelreader-kernelnpz>`__ file.
        :type K1_train: `SnpReader <http://fastlmm.github.io/PySnpTools/#snpreader-snpreader>`__ or a string or
               `KernelReader <http://fastlmm.github.io/PySnpTools/#kernelreader-kernelreader>`__

        :param h2raw: A parameter to LMM learning that tells how much weight to give the K's vs. the identity matrix, optional 
                If not given will search for best value.
                If mixing is unspecified, then h2raw must also be unspecified.
        :type h2raw: number

        :param mixing: Weight between 0.0 (inclusive, default) and 1.0 (inclusive) given to K1_train relative to K0_train.
                If you give no mixing number and a K1_train is given, the best weight will be learned.
        :type mixing: number

        :param count_A1: If it needs to read SNP data from a BED-formatted file, tells if it should count the number of A1
             alleles (the PLINK standard) or the number of A2 alleles. False is the current default, but in the future the default will change to True.
        :type count_A1: bool

        :raises AssertionError: if y is not given or has more than one phenotype.

        :rtype: self, the fitted FastLMM predictor
        """
        # Force the numpy array backend for everything computed during fitting.
        with patch.dict('os.environ', {'ARRAY_MODULE': 'numpy'}):

            self.is_fitted = True
            # should this have a cache file like 'single_snp'?
            #!!!later what happens if missing values in pheno_train?
            #!!!later add code so that X, y, etc can be array-like objects without iid information. In that case, make up iid info

            assert y is not None, "y must be given"

            y = _pheno_fixup(y, count_A1=count_A1)
            assert y.sid_count == 1, "Expect y to be just one variable"
            X = _pheno_fixup(X, iid_if_none=y.iid, count_A1=count_A1)

            K0_train = _kernel_fixup(K0_train,
                                     iid_if_none=y.iid,
                                     standardizer=self.snp_standardizer,
                                     count_A1=count_A1)
            K1_train = _kernel_fixup(K1_train,
                                     iid_if_none=y.iid,
                                     standardizer=self.snp_standardizer,
                                     count_A1=count_A1)

            # Reorder and intersect so every input covers the same iids in the same order.
            K0_train, K1_train, X, y = intersect_apply(
                [K0_train, K1_train, X, y], intersect_before_standardize=True
            )  #!!! test this on both K's as None
            from fastlmm.association.single_snp import _set_block_size
            K0_train, K1_train, block_size = _set_block_size(
                K0_train, K1_train, mixing, self.GB_goal, self.force_full_rank,
                self.force_low_rank)

            X = X.read()
            # If possible, unit standardize train and test together. If that is not possible, unit standardize only train and later apply
            # the same linear transformation to test. Unit standardization is necessary for FastLMM to work correctly.
            #!!!later is the calculation of the training data's stats done twice???
            X, covar_unit_trained = X.standardize(
                self.covariate_standardizer,
                block_size=block_size,
                return_trained=True)  # This also fills missing with the mean

            # add a column of 1's to cov to increase DOF of model (and accuracy) by allowing a constant offset
            X = SnpData(iid=X.iid,
                        sid=self._new_snp_name(X),
                        val=np.c_[X.val, np.ones((X.iid_count, 1))],
                        name="covariate_train w/ 1's")

            y0 = y.read().val  #!!!later would view_ok=True,order='A' be ok because this code already did a fresh read to look for any missing values

            from fastlmm.association.single_snp import _Mixer  #!!!move _combine_the_best_way to another file (e.g. this one)
            K_train, h2raw, mixer = _Mixer.combine_the_best_way(
                K0_train,
                K1_train,
                X.val,
                y0,
                mixing,
                h2raw,
                force_full_rank=self.force_full_rank,
                force_low_rank=self.force_low_rank,
                kernel_standardizer=self.kernel_standardizer,
                block_size=block_size)

            # do final prediction using lmm.py
            lmm = LMM()

            # Special case: The K kernel is defined implicitly with SNP data
            if mixer.do_g:
                assert isinstance(
                    K_train.standardizer,
                    StandardizerIdentity), "Expect Identity standardizer"
                lmm.setG(G0=K_train.snpreader.val)
            else:
                lmm.setK(K0=K_train.val)

            lmm.setX(X.val)
            lmm.sety(y0[:, 0])

            # Find the best h2 and also on covariates (not given from new model)
            if h2raw is None:
                res = lmm.findH2()  #!!!why is REML true in the return???
            else:
                res = lmm.nLLeval(h2=h2raw)

            # We compute sigma2 instead of using res['sigma2'] because res['sigma2'] is only the pure noise.
            # np.sum reduces the (n,1) residual array to a scalar in one vectorized pass;
            # the builtin sum looped in Python and produced a shape-(1,) array, and
            # float() on an ndim>0 array is deprecated in recent NumPy.
            residual = np.dot(X.val, res['beta']).reshape(-1, 1) - y0
            full_sigma2 = float(np.sum(residual ** 2)) / y.iid_count  #!!! this is non REML. Is that right?

            ###### all references to 'fastlmm_model' should be here so that we don't forget any
            self.block_size = block_size
            self.beta = res['beta']
            self.h2raw = res['h2']
            self.sigma2 = full_sigma2
            self.U = lmm.U
            self.S = lmm.S
            self.K = lmm.K
            self.G = lmm.G
            self.y = lmm.y
            self.Uy = lmm.Uy
            self.X = lmm.X
            self.UX = lmm.UX
            self.mixer = mixer
            self.covar_unit_trained = covar_unit_trained
            self.K_train_iid = K_train.iid
            self.covar_sid = X.sid
            self.pheno_sid = y.sid
            self.G0_train = K0_train.snpreader if isinstance(
                K0_train, SnpKernel) else None  #!!!later expensive?
            self.G1_train = K1_train.snpreader if isinstance(
                K1_train, SnpKernel) else None  #!!!later expensive?
            return self
def work_item(arg_tuple):
    """Fit G, G+E and (G+E)+GxE linear mixed models for a single phenotype.

    ``arg_tuple`` is unpacked as::

        (pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,
         (jackknife_index, jackknife_count, jackknife_seed),
         (permute_plus_index, permute_plus_count, permute_plus_seed),
         (permute_times_index, permute_times_count, permute_times_seed),
         just_testing, do_uncorr, do_gxe2, a2)

    Individuals whose first phenotype value is NaN are dropped. A negative
    jackknife or permutation index disables that step. ``a2`` fixes the G/E
    mixing weight; when it is None, the weight is searched for. ``just_testing``
    skips the LMM searches and uses placeholder results.

    :return: dict of the estimates together with the settings that produced them.
    """
    (
        pheno,
        G_kernel,
        spatial_coor,
        spatial_iid,
        alpha,
        alpha_power,  # The main inputs
        (jackknife_index, jackknife_count,
         jackknife_seed),  # Jackknifing and permutations inputs
        (permute_plus_index, permute_plus_count, permute_plus_seed),
        (permute_times_index, permute_times_count, permute_times_seed),
        just_testing,
        do_uncorr,
        do_gxe2,
        a2) = arg_tuple  # Shortcutting work

    #########################################
    # Remove any missing values from pheno
    #########################################
    pheno = pheno.read()
    pheno = pheno[~np.isnan(pheno.val[:, 0]), :]

    #########################################
    # Environment: Turn spatial info into a KernelData
    #########################################
    spatial_val = spatial_similarity(spatial_coor, alpha, power=alpha_power)
    E_kernel = KernelData(iid=spatial_iid, val=spatial_val)

    #########################################
    # Intersect, apply the jackknife or permutation, and then (because we now know the iids) standardize appropriately
    #########################################
    from pysnptools.util import intersect_apply
    G_kernel, E_kernel, pheno = intersect_apply([G_kernel, E_kernel, pheno])

    if jackknife_index >= 0:
        assert jackknife_count <= G_kernel.iid_count, "expect the number of groups to be less than the number of iids"
        assert jackknife_index < jackknife_count, "expect the jackknife index to be less than the count"
        # NOTE(review): sklearn.cross_validation was removed in scikit-learn 0.20;
        # newer code uses model_selection.KFold(n_splits=...).split(...). Kept
        # as-is because this file's imports are not visible here.
        m_fold = cross_validation.KFold(n=G_kernel.iid_count,
                                        n_folds=jackknife_count,
                                        shuffle=True,
                                        random_state=jackknife_seed %
                                        4294967295)
        # Keep the training part of the jackknife_index-th fold.
        iid_index, _ = _nth(m_fold, jackknife_index)
        pheno = pheno[iid_index, :]
        G_kernel = G_kernel[iid_index]
        E_kernel = E_kernel[iid_index]

    if permute_plus_index >= 0:
        # We shuffle the val, but not the iid, because that would cancel out.
        # Integrate the permute_plus_index into the random seed.
        # A private RandomState gives the same permutation as seeding the global
        # generator and calling np.random.shuffle, without altering global RNG state.
        new_index = np.random.RandomState(
            (permute_plus_seed + permute_plus_index) % 4294967295
        ).permutation(G_kernel.iid_count)
        E_kernel_temp = E_kernel[new_index].read()
        E_kernel = KernelData(
            iid=E_kernel.iid,
            val=E_kernel_temp.val,
            parent_string="permutation {0}".format(permute_plus_index))

    pheno = pheno.read().standardize()  # defaults to Unit standardize
    G_kernel = G_kernel.read().standardize()  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize()  # defaults to DiagKtoN standardize

    #########################################
    # find h2uncorr, the best mixing weight of pure random noise and G_kernel
    #########################################

    if not do_uncorr:
        h2uncorr, nLLuncorr = np.nan, np.nan
    else:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if not just_testing:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        else:
            h2uncorr, nLLuncorr = 0, 0
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(
            h2uncorr, nLLuncorr))

    #########################################
    # Find a2, the best mixing for G_kernel and E_kernel
    #########################################

    if a2 is None:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
        lmm1.setK(K0=G_kernel.val, K1=E_kernel.val, a2=0.5)
        lmm1.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm1.sety(pheno.val[:, 0])
        if not just_testing:
            res1 = lmm1.findA2()
            h2, a2, nLLcorr = res1["h2"], res1["a2"], res1["nLL"]
            # Split the total kernel weight h2 between G (1-a2) and E (a2).
            h2corr = h2 * (1 - a2)
            e2 = h2 * a2
        else:
            h2corr, e2, a2, nLLcorr = 0, 0, .5, 0
        logging.info(
            "G plus E mixture: h2corr: {0}, e2: {1}, a2: {2}, nLLcorr: {3}".
            format(h2corr, e2, a2, nLLcorr))
    else:
        h2corr, e2, nLLcorr = np.nan, np.nan, np.nan

    #########################################
    # Find a2_gxe2, the best mixing for G+E_kernel and the GxE kernel
    #########################################

    if not do_gxe2:
        gxe2, a2_gxe2, nLL_gxe2 = np.nan, np.nan, np.nan
    else:
        # Create the G+E kernel by mixing according to a2
        val = (1 - a2) * G_kernel.val + a2 * E_kernel.val
        GplusE_kernel = KernelData(iid=G_kernel.iid,
                                   val=val,
                                   parent_string="{0} G + {1} E".format(
                                       1 - a2, a2))
        # Don't need to standardize GplusE_kernel because it's the weighted combination of standardized kernels

        # Create GxE Kernel and then find the best mixing of it and GplusE
        logging.info("Find best mixing for GxE and GplusE_kernel")

        val = G_kernel.val * E_kernel.val  # recall that Python '*' is just element-wise multiplication
        if permute_times_index >= 0:
            # We shuffle the val, but not the iid, because doing both would cancel out
            new_index = np.random.RandomState(
                (permute_times_seed + permute_times_index) % 4294967295
            ).permutation(G_kernel.iid_count)
            val = pstutil.sub_matrix(val, new_index, new_index)

        GxE_kernel = KernelData(iid=G_kernel.iid, val=val, parent_string="GxE")
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if not just_testing:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            # Only the a2_gxe2 share of the kernel weight belongs to GxE.
            gxe2 *= a2_gxe2
        else:
            gxe2, a2_gxe2, nLL_gxe2 = 0, .5, 0
        logging.info(
            "G+E plus GxE mixture: gxe2: {0}, a2_gxe2: {1}, nLL_gxe2: {2}".
            format(gxe2, a2_gxe2, nLL_gxe2))

    #########################################
    # Return results
    #########################################

    ret = {
        "h2uncorr": h2uncorr,
        "nLLuncorr": nLLuncorr,
        "h2corr": h2corr,
        "e2": e2,
        "a2": a2,
        "nLLcorr": nLLcorr,
        "gxe2": gxe2,
        "a2_gxe2": a2_gxe2,
        "nLL_gxe2": nLL_gxe2,
        "alpha": alpha,
        "alpha_power": alpha_power,
        "phen": pheno.sid[0],
        "jackknife_index": jackknife_index,
        "jackknife_count": jackknife_count,
        "jackknife_seed": jackknife_seed,
        "permute_plus_index": permute_plus_index,
        "permute_plus_count": permute_plus_count,
        "permute_plus_seed": permute_plus_seed,
        "permute_times_index": permute_times_index,
        "permute_times_count": permute_times_count,
        "permute_times_seed": permute_times_seed
    }

    logging.info("run_line: {0}".format(ret))
    return ret
def work_item(arg_tuple):
    """Estimate G, E and GxE variance components for a single phenotype.

    ``arg_tuple`` is unpacked as::

        (pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,
         (jackknife_index, jackknife_count, jackknife_seed),
         (permute_plus_index, permute_plus_count, permute_plus_seed),
         (permute_times_index, permute_times_count, permute_times_seed),
         just_testing, do_uncorr, do_gxe2, a2)

    (Packing everything into one tuple presumably lets the caller use a
    map-style runner -- TODO confirm against the caller.)

    Steps:
      1. Drop individuals whose phenotype value is NaN.
      2. Build the environment kernel E from
         ``spatial_similarity(spatial_coor, alpha, power=alpha_power)``
         labelled with ``spatial_iid``.
      3. Intersect G, E and pheno on iids. When ``jackknife_index >= 0``,
         keep only the training part of that KFold fold. When
         ``permute_plus_index >= 0``, permute E's values (not its iids).
         Then standardize pheno, G and E.
      4. If ``do_uncorr``: fit an LMM with G alone -> ``h2uncorr``,
         ``nLLuncorr``.
      5. If ``a2`` is None: fit the best G/E mixture -> ``h2corr``, ``e2``,
         ``a2``, ``nLLcorr``. Otherwise the supplied ``a2`` is used as-is and
         ``h2corr``, ``e2``, ``nLLcorr`` are NaN.
      6. If ``do_gxe2``: fit the best mixture of ``(1-a2)*G + a2*E`` and the
         element-wise product ``G*E``. When ``permute_times_index >= 0``,
         rows/columns of ``G*E`` are permuted. Results are ``gxe2``,
         ``a2_gxe2`` and ``nLL_gxe2``.

    With ``just_testing`` true, the LMM searches are skipped and fixed
    placeholder values are reported.

    Returns:
        dict with the estimates above, ``alpha``, ``alpha_power``, the
        phenotype name (``"phen"``), and every jackknife/permutation setting.
        Skipped estimates are ``np.nan``.

    Raises:
        AssertionError: if ``pheno`` has more than one column, or if the
            jackknife index/count are inconsistent with the number of iids.
    """
    (pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,    # The main inputs
     (jackknife_index, jackknife_count, jackknife_seed),               # Jackknifing and permutations inputs
     (permute_plus_index, permute_plus_count, permute_plus_seed),
     (permute_times_index, permute_times_count, permute_times_seed),
     just_testing, do_uncorr, do_gxe2, a2) = arg_tuple                 # Shortcutting work

    #########################################
    # Remove any missing values from pheno
    #########################################
    # Only column 0 is filtered, fitted and reported (pheno.sid[0]), so refuse
    # multi-column input instead of silently ignoring the extra columns.
    assert pheno.sid_count == 1, "Expect only one pheno in work_item"
    pheno = pheno.read()
    pheno = pheno[~np.isnan(pheno.val[:, 0]), :]

    #########################################
    # Environment: Turn spatial info into a KernelData
    #########################################
    spatial_val = spatial_similarity(spatial_coor, alpha, power=alpha_power)
    E_kernel = KernelData(iid=spatial_iid, val=spatial_val)

    #########################################
    # Intersect, apply the jackknife or permutation, and then (because we now know the iids) standardize appropriately
    #########################################
    from pysnptools.util import intersect_apply
    G_kernel, E_kernel, pheno = intersect_apply([G_kernel, E_kernel, pheno])

    if jackknife_index >= 0:
        assert jackknife_count <= G_kernel.iid_count, "expect the number of groups to be less than the number of iids"
        assert jackknife_index < jackknife_count, "expect the jackknife index to be less than the count"
        m_fold = cross_validation.KFold(n=G_kernel.iid_count, n_folds=jackknife_count, shuffle=True,
                                        random_state=jackknife_seed % 4294967295)
        # Keep the first element returned by _nth for this fold (the training indices).
        iid_index, _ = _nth(m_fold, jackknife_index)
        pheno = pheno[iid_index, :]
        G_kernel = G_kernel[iid_index]
        E_kernel = E_kernel[iid_index]

    if permute_plus_index >= 0:
        # Shuffle E's values but keep its iid labels: permuting both would cancel out.
        # Folding permute_plus_index into the seed gives each permutation its own shuffle.
        np.random.seed((permute_plus_seed + permute_plus_index) % 4294967295)
        new_index = np.arange(G_kernel.iid_count)
        np.random.shuffle(new_index)
        E_kernel_permuted = E_kernel[new_index].read()
        E_kernel = KernelData(iid=E_kernel.iid, val=E_kernel_permuted.val,
                              name="permutation {0}".format(permute_plus_index))

    pheno = pheno.read().standardize()        # defaults to Unit standardize
    G_kernel = G_kernel.read().standardize()  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize()  # defaults to DiagKtoN standardize

    #########################################
    # Find h2uncorr, the best mixing weight of pure random noise and G_kernel
    #########################################

    if not do_uncorr:
        h2uncorr, nLLuncorr = np.nan, np.nan
    else:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if not just_testing:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        else:
            h2uncorr, nLLuncorr = 0, 0
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(h2uncorr, nLLuncorr))

    #########################################
    # Find a2, the best mixing for G_kernel and E_kernel
    #########################################

    if a2 is None:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
        lmm1.setK(K0=G_kernel.val, K1=E_kernel.val, a2=0.5)
        lmm1.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm1.sety(pheno.val[:, 0])
        if not just_testing:
            res1 = lmm1.findA2()
            h2, a2, nLLcorr = res1["h2"], res1["a2"], res1["nLL"]
            # Split the total h2 between G and E according to the mixing weight a2.
            h2corr = h2 * (1 - a2)
            e2 = h2 * a2
        else:
            h2corr, e2, a2, nLLcorr = 0, 0, .5, 0
        logging.info("G plus E mixture: h2corr: {0}, e2: {1}, a2: {2}, nLLcorr: {3}".format(h2corr, e2, a2, nLLcorr))
    else:
        # a2 was supplied by the caller, so the G/E search is skipped.
        h2corr, e2, nLLcorr = np.nan, np.nan, np.nan

    #########################################
    # Find a2_gxe2, the best mixing for G+E_kernel and the GxE kernel
    #########################################

    if not do_gxe2:
        gxe2, a2_gxe2, nLL_gxe2 = np.nan, np.nan, np.nan
    else:
        # Create the G+E kernel by mixing according to a2
        val = (1 - a2) * G_kernel.val + a2 * E_kernel.val
        GplusE_kernel = KernelData(iid=G_kernel.iid, val=val, name="{0} G + {1} E".format(1 - a2, a2))
        # Don't need to standardize GplusE_kernel because it's the weighted combination of standardized kernels

        # Create GxE Kernel and then find the best mixing of it and GplusE
        logging.info("Find best mixing for GxE and GplusE_kernel")

        val = G_kernel.val * E_kernel.val  # numpy '*' is element-wise multiplication
        if permute_times_index >= 0:
            # Shuffle the val, but not the iid, because doing both would cancel out
            np.random.seed((permute_times_seed + permute_times_index) % 4294967295)
            new_index = np.arange(G_kernel.iid_count)
            np.random.shuffle(new_index)
            val = pstutil.sub_matrix(val, new_index, new_index)

        GxE_kernel = KernelData(iid=G_kernel.iid, val=val, name="GxE")
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if not just_testing:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            # Report only the share of h2 attributed to the GxE kernel.
            gxe2 *= a2_gxe2
        else:
            gxe2, a2_gxe2, nLL_gxe2 = 0, .5, 0
        logging.info("G+E plus GxE mixture: gxe2: {0}, a2_gxe2: {1}, nLL_gxe2: {2}".format(gxe2, a2_gxe2, nLL_gxe2))

    #########################################
    # Return results
    #########################################

    ret = {"h2uncorr": h2uncorr, "nLLuncorr": nLLuncorr, "h2corr": h2corr, "e2": e2, "a2": a2, "nLLcorr": nLLcorr,
           "gxe2": gxe2, "a2_gxe2": a2_gxe2, "nLL_gxe2": nLL_gxe2, "alpha": alpha, "alpha_power": alpha_power, "phen": pheno.sid[0],
           "jackknife_index": jackknife_index, "jackknife_count": jackknife_count, "jackknife_seed": jackknife_seed,
           "permute_plus_index": permute_plus_index, "permute_plus_count": permute_plus_count, "permute_plus_seed": permute_plus_seed,
           "permute_times_index": permute_times_index, "permute_times_count": permute_times_count, "permute_times_seed": permute_times_seed
           }

    logging.info("run_line: {0}".format(ret))
    return ret