Exemple #1
0
    def test_sub_matrix(self):
        """Check that ``sub_matrix`` selects and reorders rows and columns.

        The same selection (rows 0, 2, 11 and all seven columns in reverse
        order) is applied to a 12 x 7 array and to a 12 x 7 x 3 array. In
        both cases original row 2 must end up at output row 1 and original
        column 0 at output column 6.
        """
        import pysnptools.util as pstutil

        # Each case: (array shape, position in the input, position in the output).
        cases = (
            ((12, 7), (2, 0), (1, 6)),
            ((12, 7, 3), (2, 0, 1), (1, 6, 1)),
        )
        for shape, source_pos, target_pos in cases:
            np.random.seed(0)  # fixed seed keeps the random data deterministic
            matrix = np.random.rand(*shape)
            submatrix = pstutil.sub_matrix(matrix, [0, 2, 11],
                                           [6, 5, 4, 3, 2, 1, 0])
            # Row #2 is now #1 and column #0 is now #6.
            assert matrix[source_pos] == submatrix[target_pos]
    def _apply_sparray_or_slice_to_val(self, val, row_indexer_or_none, col_indexer_or_none, order, dtype, force_python_only):
        """Subset ``val`` by a row indexer and a column indexer.

        Returns a pair ``(sub_val, shares_memory)``:

        * ``(val, True)`` when ``PstReader._is_all_slice`` accepts both
          indexers, ``force_python_only`` is false, ``order`` is compatible
          with ``val``'s memory layout and ``dtype`` is ``None`` or equal to
          ``val.dtype``. No copy is made in that case.
        * ``(pstutil.sub_matrix(...), False)`` for every other call with
          ``force_python_only`` false.
        * Otherwise the subset is built with NumPy indexing, converted with
          ``astype`` if its order/dtype are not acceptable, and the second
          element is whatever ``np.may_share_memory`` reports.
        """
        # Fast path: nothing to select and nothing to convert.
        if PstReader._is_all_slice(row_indexer_or_none) and PstReader._is_all_slice(col_indexer_or_none):
            if not force_python_only:
                layout_ok = (order == 'A'
                             or (order == 'F' and val.flags['F_CONTIGUOUS'])
                             or (order == 'C' and val.flags['C_CONTIGUOUS']))
                if layout_ok and (dtype is None or val.dtype == dtype):
                    return val, True

        rows = PstReader._make_sparray_or_slice(row_indexer_or_none)
        cols = PstReader._make_sparray_or_slice(col_indexer_or_none)

        if not force_python_only:
            # Delegate to pstutil.sub_matrix, which takes explicit index arrays.
            row_positions = PstReader._make_sparray_from_sparray_or_slice(self.row_count, rows)
            col_positions = PstReader._make_sparray_from_sparray_or_slice(self.col_count, cols)
            picked = pstutil.sub_matrix(val, row_positions, col_positions, order=order, dtype=dtype)
            return picked, False

        # Pure-NumPy path.
        if PstReader._is_all_slice(rows) or PstReader._is_all_slice(cols):
            picked = val[rows, cols]  #!!is this faster than the C++?
        else:
            row_positions = PstReader._make_sparray_from_sparray_or_slice(self.row_count, rows)
            col_positions = PstReader._make_sparray_from_sparray_or_slice(self.col_count, cols)
            # Indexing with a column vector of rows and a flat vector of columns
            # selects the full cross product. See
            # http://stackoverflow.com/questions/21349133/numpy-array-integer-indexing-in-more-than-one-dimension
            picked = val[row_positions.reshape(-1, 1), col_positions]

        assert len(picked.shape) == 2, "Expect result of subsetting to be 2 dimensional"

        if not PstReader._array_properties_are_ok(picked, order, dtype):
            # Fill in defaults before converting: keep the layout ("K") and the
            # current dtype when the caller did not ask for anything specific.
            if order is None:
                order = "K"
            if dtype is None:
                dtype = picked.dtype
            picked = picked.astype(dtype, order, copy=True)

        shares_memory = np.may_share_memory(val, picked)
        assert PstReader._array_properties_are_ok(picked, order, dtype)
        return picked, shares_memory
Exemple #3
0
    def _apply_sparray_or_slice_to_val(self, val, row_indexer_or_none,
                                       col_indexer_or_none, order, dtype,
                                       force_python_only, num_threads):
        """Subset ``val`` by a row indexer and a column indexer.

        ``dtype`` is first normalized with ``np.dtype``. The result is a pair
        ``(sub_val, shares_memory)``:

        * ``(val, True)`` when ``PstReader._is_all_slice`` accepts both
          indexers, ``order`` is compatible with ``val``'s memory layout and
          the dtype matches. No copy is made.
        * ``(pstutil.sub_matrix(...), False)`` when ``force_python_only`` is
          false; ``num_threads`` is forwarded to that call.
        * Otherwise the subset is built with NumPy indexing (2- or
          3-dimensional), converted with ``astype`` if its order/dtype are not
          acceptable, and the second element is whatever
          ``np.may_share_memory`` reports.
        """
        dtype = np.dtype(dtype)
        # NOTE(review): np.dtype(...) returns a dtype object, so the
        # `dtype is None` tests further down look unreachable -- confirm and
        # consider removing them.

        # Fast path: nothing to select and nothing to convert.
        if PstReader._is_all_slice(row_indexer_or_none):
            if PstReader._is_all_slice(col_indexer_or_none):
                layout_ok = (order == 'A'
                             or (order == 'F' and val.flags['F_CONTIGUOUS'])
                             or (order == 'C' and val.flags['C_CONTIGUOUS']))
                if layout_ok and (dtype is None or val.dtype == dtype):
                    return val, True

        rows = PstReader._make_sparray_or_slice(row_indexer_or_none)
        cols = PstReader._make_sparray_or_slice(col_indexer_or_none)

        if not force_python_only:
            # Delegate to pstutil.sub_matrix, which takes explicit index arrays.
            row_positions = PstReader._make_sparray_from_sparray_or_slice(
                self.row_count, rows)
            col_positions = PstReader._make_sparray_from_sparray_or_slice(
                self.col_count, cols)
            picked = pstutil.sub_matrix(val,
                                        row_positions,
                                        col_positions,
                                        order=order,
                                        dtype=dtype,
                                        num_threads=num_threads)
            return picked, False

        # Pure-NumPy path.
        if PstReader._is_all_slice(rows) or PstReader._is_all_slice(cols):
            picked = val[rows, cols]  #!!is this faster than the C++?
        else:
            row_positions = PstReader._make_sparray_from_sparray_or_slice(
                self.row_count, rows)
            col_positions = PstReader._make_sparray_from_sparray_or_slice(
                self.col_count, cols)
            # Indexing with a column vector of rows and a flat vector of columns
            # selects the full cross product. See
            # http://stackoverflow.com/questions/21349133/numpy-array-integer-indexing-in-more-than-one-dimension
            picked = val[row_positions.reshape(-1, 1), col_positions]

        assert len(picked.shape) in {
            2, 3
        }, "Expect result of subsetting to be 2 or 3 dimensional"

        if not PstReader._array_properties_are_ok(picked, order, dtype):
            if order is None:
                order = "K"
            if dtype is None:
                dtype = picked.dtype
            picked = picked.astype(dtype, order, copy=True)

        shares_memory = np.may_share_memory(val, picked)
        assert PstReader._array_properties_are_ok(picked, order, dtype)
        return picked, shares_memory
Exemple #4
0
    def _read(self, row_index_or_none, col_index_or_none, order, dtype,
              force_python_only, view_ok, num_threads):
        """Generate the requested values block by block.

        The result is a new array of shape ``(rows, cols, 3)`` with the
        requested ``dtype`` and ``order`` (``'A'`` is treated as ``'F'``).
        ``None`` for an index means "all": ``self._iid_count`` rows or
        ``self._sid_count`` columns. Columns are grouped by the block they
        fall in (``column // self._block_size``); each distinct block is
        produced once with ``self._get_val`` and the wanted columns (and
        rows, via ``pstutil.sub_matrix``) are copied into the result.

        ``force_python_only`` and ``view_ok`` are accepted but not used here.
        """
        self._run_once()
        import pysnptools.util as pstutil

        if order == 'A':
            order = 'F'
        dtype = np.dtype(dtype)

        # Number of output rows, e.g. all of them.
        if row_index_or_none is None:
            out_row_count = self._iid_count
        else:
            out_row_count = len(row_index_or_none)

        # Array of wanted column positions, e.g. 0,1,200,2200,10
        if col_index_or_none is None:
            wanted_cols = np.arange(self._sid_count)
        else:
            wanted_cols = col_index_or_none

        # Block number of every wanted column, e.g. 0,0,0,2,0
        block_of_col = wanted_cols // self._block_size

        # Allocate memory for the result.
        val = np.empty((out_row_count, len(wanted_cols), 3),
                       order=order,
                       dtype=dtype)

        for block_id in list(set(block_of_col)):
            # Convert np.uint64 to python int to avoid uint*int->float.
            block_id = int(block_id)
            first = block_id * self._block_size  #e.g. 0 (then 2000)
            last = first + self._block_size  #e.g. 1000, then 3000
            block_val = self._get_val(first, last, dtype)  # generate whole block

            #e.g. [True,True,True,False,True], then [False,False,False,True,False]
            in_block = block_of_col == block_id
            offsets = wanted_cols[in_block] - first  #e.g. 0,1,200,10, then 200

            if row_index_or_none is None:
                val[:, in_block, :] = block_val[:, offsets, :]
            else:
                val[:, in_block, :] = pstutil.sub_matrix(
                    block_val,
                    row_index_or_none,
                    offsets,
                    num_threads=num_threads)

        return val
Exemple #5
0
    def _read(self, row_index_or_none, col_index_or_none, order, dtype,
              force_python_only, view_ok):
        """Generate the requested values batch by batch.

        The result is a new array of shape ``(rows, cols)`` with the
        requested ``dtype`` and ``order`` (``'A'`` is treated as ``'F'``).
        ``None`` for an index means "all": ``self._iid_count`` rows or
        ``self._sid_count`` columns. Columns are grouped by the batch they
        fall in (``column // self._block_size``); each distinct batch is
        produced once with ``self._get_val2`` and the wanted columns (and
        rows, via ``pstutil.sub_matrix``) are copied into the result.
        Progress is reported through ``log_in_place``.

        ``force_python_only`` and ``view_ok`` are accepted but not used here.
        """
        self._run_once()
        import pysnptools.util as pstutil
        dtype = np.dtype(dtype)

        if order == 'A':
            order = 'F'

        row_index_count = len(
            row_index_or_none
        ) if row_index_or_none is not None else self._iid_count  # turn to a count of the index positions e.g. all of them
        col_index = col_index_or_none if col_index_or_none is not None else np.arange(
            self._sid_count
        )  # turn to an array of index positions, e.g. 0,1,200,2200,10
        batch_index = col_index // self._block_size  #find the batch index of each index position, e.g. 0,0,0,2,0
        val = np.empty((row_index_count, len(col_index)),
                       order=order,
                       dtype=dtype)  #allocate memory for result
        list_batch_index = list(set(batch_index))
        with log_in_place("working on snpgen batch", logging.INFO) as updater:
            for ii, i in enumerate(
                    list_batch_index
            ):  #for each distinct batch index, generate snps
                updater("{0} of {1}".format(ii, len(list_batch_index)))
                # Convert the NumPy scalar to a python int. With an unsigned
                # index array, np.uint64 * int can promote to float, which
                # would turn `start` and the column offsets below into
                # floats and break the integer indexing.
                i = int(i)
                start = i * self._block_size  #e.g. 0 (then 2000)
                stop = start + self._block_size  #e.g. 1000, then 3000
                batch_val = self._get_val2(start,
                                           stop,
                                           order=order,
                                           dtype=dtype)  # generate whole batch
                a = (
                    batch_index == i
                )  #e.g. [True,True,True,False,True], then [False,False,False,True,False]
                b = col_index[a] - start  #e.g.  0,1,200,10, then 200
                val[:,
                    a] = batch_val[:,
                                   b] if row_index_or_none is None else pstutil.sub_matrix(
                                       batch_val, row_index_or_none, b)

        return val
Exemple #6
0
def work_item2(pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,
               xxx_todo_changeme, xxx_todo_changeme1, xxx_todo_changeme2,
               just_testing, do_uncorr, do_gxe2, a2):
    """Run one G / E / GxE variance-component fit and return its results.

    ``xxx_todo_changeme``, ``xxx_todo_changeme1`` and ``xxx_todo_changeme2``
    are 3-tuples ``(index, count, seed)`` for the jackknife, the "plus"
    permutation and the "times" permutation respectively; a negative index
    switches the corresponding step off.

    ``spatial_coor`` may be a file name, in which case it is read as a
    space-delimited table with columns ``id``, ``south_new`` and ``east_new``
    (rows with missing values are dropped) and ``spatial_iid`` must be
    ``None``.

    Steps: drop NaN phenotypes, build the spatial (E) kernel, intersect the
    inputs, optionally jackknife / permute, standardize, then fit
    (1) G alone when ``do_uncorr``, (2) the G/E mixture when ``a2`` is
    ``None``, (3) the (G+E)/GxE mixture when ``do_gxe2``. Skipped fits
    report ``np.nan``; with ``just_testing`` the fits are replaced by
    placeholder values.

    Returns a dict with the fitted quantities and the run parameters.
    """
    jackknife_index, jackknife_count, jackknife_seed = xxx_todo_changeme
    permute_plus_index, permute_plus_count, permute_plus_seed = xxx_todo_changeme1
    permute_times_index, permute_times_count, permute_times_seed = xxx_todo_changeme2

    # ----- Load GPS info from a file if that is how it was given -----
    if isinstance(spatial_coor, str):
        assert spatial_iid is None, "if spatial_coor is a str, then spatial_iid should be None"
        gps_table = pd.read_csv(spatial_coor, delimiter=" ").dropna()
        spatial_iid = np.array([(one_id, one_id)
                                for one_id in gps_table["id"].values])
        spatial_coor = gps_table[["south_new", "east_new"]].values

    # ----- Remove any missing values from pheno -----
    assert pheno.sid_count == 1, "Expect only one pheno in work_item"
    pheno = pheno.read()
    first_column = pheno.val[:, 0]
    # NaN is the only value not equal to itself, so this keeps non-NaN rows.
    pheno = pheno[first_column == first_column, :]

    # ----- Environment: turn the spatial info into a KernelData -----
    spatial_val = spatial_similarity(spatial_coor, alpha, power=alpha_power)
    E_kernel = KernelData(iid=spatial_iid, val=spatial_val)

    # ----- Intersect, jackknife / permute, then standardize -----
    from pysnptools.util import intersect_apply
    G_kernel, E_kernel, pheno = intersect_apply([G_kernel, E_kernel, pheno])

    if jackknife_index >= 0:
        assert jackknife_count <= G_kernel.iid_count, "expect the number of groups to be less than the number of iids"
        assert jackknife_index < jackknife_count, "expect the jackknife index to be less than the count"
        splitter = model_selection.KFold(n_splits=jackknife_count,
                                         shuffle=True,
                                         random_state=jackknife_seed % 4294967295)
        folds = splitter.split(list(range(G_kernel.iid_count)))
        iid_index, _ = _nth(folds, jackknife_index)
        pheno = pheno[iid_index, :]
        G_kernel = G_kernel[iid_index]
        E_kernel = E_kernel[iid_index]

    if permute_plus_index >= 0:
        # Shuffle the values but keep the iids; shuffling both would cancel
        # out. The permutation index is folded into the seed.
        np.random.seed((permute_plus_seed + permute_plus_index) % 4294967295)
        new_index = np.arange(G_kernel.iid_count)
        np.random.shuffle(new_index)
        shuffled_E = E_kernel[new_index].read()
        E_kernel = KernelData(
            iid=E_kernel.iid,
            val=shuffled_E.val,
            name="permutation {0}".format(permute_plus_index))

    pheno = pheno.read().standardize()  # defaults to Unit standardize
    G_kernel = G_kernel.read().standardize()  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize()  # defaults to DiagKtoN standardize

    iid_count = G_kernel.iid_count  # G_kernel is not reassigned below

    # ----- h2uncorr: best mixing weight of pure noise and G_kernel -----
    if do_uncorr:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if just_testing:
            h2uncorr, nLLuncorr = 0, 0
        else:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(
            h2uncorr, nLLuncorr))
    else:
        h2uncorr, nLLuncorr = np.nan, np.nan

    # ----- a2: best mixing of G_kernel and E_kernel -----
    if a2 is not None:
        h2corr, e2, nLLcorr, h2corr_raw = np.nan, np.nan, np.nan, np.nan
    else:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
        lmm1.setK(K0=G_kernel.val, K1=E_kernel.val, a2=0.5)
        lmm1.setX(np.ones([iid_count, 1]))  # just a bias column
        lmm1.sety(pheno.val[:, 0])
        if just_testing:
            h2corr, e2, a2, nLLcorr, h2corr_raw = 0, 0, .5, 0, 0
        else:
            res1 = lmm1.findA2()
            h2, a2, nLLcorr = res1["h2"], res1["a2"], res1["nLL"]
            h2corr = h2 * (1 - a2)
            e2 = h2 * a2
            h2corr_raw = h2
        logging.info(
            "G plus E mixture: h2corr: {0}, e2: {1}, a2: {2}, nLLcorr: {3} (h2corr_raw:{4})"
            .format(h2corr, e2, a2, nLLcorr, h2corr_raw))

    # ----- a2_gxe2: best mixing of the G+E kernel and the GxE kernel -----
    if do_gxe2:
        # Create the G+E kernel by mixing according to a2. It needs no
        # standardization: it is a weighted combination of standardized
        # kernels.
        mixed_val = (1 - a2) * G_kernel.val + a2 * E_kernel.val
        GplusE_kernel = KernelData(iid=G_kernel.iid,
                                   val=mixed_val,
                                   name="{0} G + {1} E".format(1 - a2, a2))

        # Create the GxE kernel and then find the best mixing of it and GplusE.
        logging.info("Find best mixing for GxE and GplusE_kernel")

        # Recall that Python '*' is just element-wise multiplication.
        product_val = G_kernel.val * E_kernel.val
        if permute_times_index >= 0:
            # Shuffle the values but keep the iids; doing both would cancel out.
            np.random.seed(
                (permute_times_seed + permute_times_index) % 4294967295)
            new_index = np.arange(iid_count)
            np.random.shuffle(new_index)
            product_val = pstutil.sub_matrix(product_val, new_index, new_index)

        GxE_kernel = KernelData(iid=G_kernel.iid, val=product_val, name="GxE")
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if just_testing:
            gxe2, a2_gxe2, nLL_gxe2 = 0, .5, 0
        else:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            gxe2 *= a2_gxe2
        logging.info(
            "G+E plus GxE mixture: gxe2: {0}, a2_gxe2: {1}, nLL_gxe2: {2}".
            format(gxe2, a2_gxe2, nLL_gxe2))
    else:
        gxe2, a2_gxe2, nLL_gxe2 = np.nan, np.nan, np.nan

    # ----- Return results -----
    ret = {
        "h2uncorr": h2uncorr,
        "nLLuncorr": nLLuncorr,
        "h2corr": h2corr,
        "h2corr_raw": h2corr_raw,
        "e2": e2,
        "a2": a2,
        "nLLcorr": nLLcorr,
        "gxe2": gxe2,
        "a2_gxe2": a2_gxe2,
        "nLL_gxe2": nLL_gxe2,
        "alpha": alpha,
        "alpha_power": alpha_power,
        "phen": np.array(pheno.sid, dtype='str')[0],
        "jackknife_index": jackknife_index,
        "jackknife_count": jackknife_count,
        "jackknife_seed": jackknife_seed,
        "permute_plus_index": permute_plus_index,
        "permute_plus_count": permute_plus_count,
        "permute_plus_seed": permute_plus_seed,
        "permute_times_index": permute_times_index,
        "permute_times_count": permute_times_count,
        "permute_times_seed": permute_times_seed,
    }

    logging.info("run_line: {0}".format(ret))
    return ret
        GplusE_kernel = KernelData(iid=G_kernel.iid,
                                   val=val,
                                   name="{0} G + {1} E".format(1 - a2, a2))
        #Don't need to standardize GplusE_kernel because it's the weighted combination of standardized kernels

        # Create GxE Kernel and then find the best mixing of it and GplusE
        logging.info("Find best mixing for GxE and GplusE_kernel")

        val = G_kernel.val * E_kernel.val
        if permute_times_index >= 0:
            #We shuffle the val, but not the iid, because doing both would cancel out
            np.random.seed(
                (permute_times_seed + permute_times_index) % 4294967295)
            new_index = np.arange(G_kernel.iid_count)
            np.random.shuffle(new_index)
            val = pstutil.sub_matrix(val, new_index, new_index)

        GxE_kernel = KernelData(
            iid=G_kernel.iid, val=val, name="GxE"
        )  # recall that Python '*' is just element-wise multiplication
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([G_kernel.iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if not just_testing:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            gxe2 *= a2_gxe2
        else:
def work_item(arg_tuple):
    """Run one G / E / GxE variance-component fit and return its results.

    ``arg_tuple`` holds, in order: ``pheno``, ``G_kernel``, ``spatial_coor``,
    ``spatial_iid``, ``alpha``, ``alpha_power``, three ``(index, count,
    seed)`` triples (jackknife, "plus" permutation, "times" permutation),
    and the shortcut switches ``just_testing``, ``do_uncorr``, ``do_gxe2``
    and ``a2``. A negative index switches the corresponding resampling step
    off.

    Steps: drop NaN phenotypes, build the spatial (E) kernel, intersect the
    inputs, optionally jackknife / permute, standardize, then fit
    (1) G alone when ``do_uncorr``, (2) the G/E mixture when ``a2`` is
    ``None``, (3) the (G+E)/GxE mixture when ``do_gxe2``. Skipped fits
    report ``np.nan``; with ``just_testing`` the fits are replaced by
    placeholder values.

    Returns a dict with the fitted quantities and the run parameters.
    """
    (pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,
     jackknife_spec, permute_plus_spec, permute_times_spec,
     just_testing, do_uncorr, do_gxe2, a2) = arg_tuple
    jackknife_index, jackknife_count, jackknife_seed = jackknife_spec
    permute_plus_index, permute_plus_count, permute_plus_seed = permute_plus_spec
    permute_times_index, permute_times_count, permute_times_seed = permute_times_spec

    # ----- Remove any missing values from pheno -----
    pheno = pheno.read()
    first_column = pheno.val[:, 0]
    # NaN is the only value not equal to itself, so this keeps non-NaN rows.
    pheno = pheno[first_column == first_column, :]

    # ----- Environment: turn the spatial info into a KernelData -----
    spatial_val = spatial_similarity(spatial_coor, alpha, power=alpha_power)
    E_kernel = KernelData(iid=spatial_iid, val=spatial_val)

    # ----- Intersect, jackknife / permute, then standardize -----
    from pysnptools.util import intersect_apply
    G_kernel, E_kernel, pheno = intersect_apply([G_kernel, E_kernel, pheno])

    if jackknife_index >= 0:
        assert jackknife_count <= G_kernel.iid_count, "expect the number of groups to be less than the number of iids"
        assert jackknife_index < jackknife_count, "expect the jackknife index to be less than the count"
        folds = cross_validation.KFold(n=G_kernel.iid_count,
                                       n_folds=jackknife_count,
                                       shuffle=True,
                                       random_state=jackknife_seed % 4294967295)
        iid_index, _ = _nth(folds, jackknife_index)
        pheno = pheno[iid_index, :]
        G_kernel = G_kernel[iid_index]
        E_kernel = E_kernel[iid_index]

    if permute_plus_index >= 0:
        # Shuffle the values but keep the iids; shuffling both would cancel
        # out. The permutation index is folded into the seed.
        np.random.seed((permute_plus_seed + permute_plus_index) % 4294967295)
        new_index = np.arange(G_kernel.iid_count)
        np.random.shuffle(new_index)
        shuffled_E = E_kernel[new_index].read()
        E_kernel = KernelData(
            iid=E_kernel.iid,
            val=shuffled_E.val,
            parent_string="permutation {0}".format(permute_plus_index))

    pheno = pheno.read().standardize()  # defaults to Unit standardize
    G_kernel = G_kernel.read().standardize()  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize()  # defaults to DiagKtoN standardize

    iid_count = G_kernel.iid_count  # G_kernel is not reassigned below

    # ----- h2uncorr: best mixing weight of pure noise and G_kernel -----
    if do_uncorr:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if just_testing:
            h2uncorr, nLLuncorr = 0, 0
        else:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(
            h2uncorr, nLLuncorr))
    else:
        h2uncorr, nLLuncorr = np.nan, np.nan

    # ----- a2: best mixing of G_kernel and E_kernel -----
    if a2 is not None:
        h2corr, e2, nLLcorr = np.nan, np.nan, np.nan
    else:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
        lmm1.setK(K0=G_kernel.val, K1=E_kernel.val, a2=0.5)
        lmm1.setX(np.ones([iid_count, 1]))  # just a bias column
        lmm1.sety(pheno.val[:, 0])
        if just_testing:
            h2corr, e2, a2, nLLcorr = 0, 0, .5, 0
        else:
            res1 = lmm1.findA2()
            h2, a2, nLLcorr = res1["h2"], res1["a2"], res1["nLL"]
            h2corr = h2 * (1 - a2)
            e2 = h2 * a2
        logging.info(
            "G plus E mixture: h2corr: {0}, e2: {1}, a2: {2}, nLLcorr: {3}".
            format(h2corr, e2, a2, nLLcorr))

    # ----- a2_gxe2: best mixing of the G+E kernel and the GxE kernel -----
    if do_gxe2:
        # Create the G+E kernel by mixing according to a2. It needs no
        # standardization: it is a weighted combination of standardized
        # kernels.
        mixed_val = (1 - a2) * G_kernel.val + a2 * E_kernel.val
        GplusE_kernel = KernelData(iid=G_kernel.iid,
                                   val=mixed_val,
                                   parent_string="{0} G + {1} E".format(
                                       1 - a2, a2))

        # Create the GxE kernel and then find the best mixing of it and GplusE.
        logging.info("Find best mixing for GxE and GplusE_kernel")

        # Recall that Python '*' is just element-wise multiplication.
        product_val = G_kernel.val * E_kernel.val
        if permute_times_index >= 0:
            # Shuffle the values but keep the iids; doing both would cancel out.
            np.random.seed(
                (permute_times_seed + permute_times_index) % 4294967295)
            new_index = np.arange(iid_count)
            np.random.shuffle(new_index)
            product_val = pstutil.sub_matrix(product_val, new_index, new_index)

        GxE_kernel = KernelData(iid=G_kernel.iid,
                                val=product_val,
                                parent_string="GxE")
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if just_testing:
            gxe2, a2_gxe2, nLL_gxe2 = 0, .5, 0
        else:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            gxe2 *= a2_gxe2
        logging.info(
            "G+E plus GxE mixture: gxe2: {0}, a2_gxe2: {1}, nLL_gxe2: {2}".
            format(gxe2, a2_gxe2, nLL_gxe2))
    else:
        gxe2, a2_gxe2, nLL_gxe2 = np.nan, np.nan, np.nan

    # ----- Return results -----
    ret = {
        "h2uncorr": h2uncorr,
        "nLLuncorr": nLLuncorr,
        "h2corr": h2corr,
        "e2": e2,
        "a2": a2,
        "nLLcorr": nLLcorr,
        "gxe2": gxe2,
        "a2_gxe2": a2_gxe2,
        "nLL_gxe2": nLL_gxe2,
        "alpha": alpha,
        "alpha_power": alpha_power,
        "phen": pheno.sid[0],
        "jackknife_index": jackknife_index,
        "jackknife_count": jackknife_count,
        "jackknife_seed": jackknife_seed,
        "permute_plus_index": permute_plus_index,
        "permute_plus_count": permute_plus_count,
        "permute_plus_seed": permute_plus_seed,
        "permute_times_index": permute_times_index,
        "permute_times_count": permute_times_count,
        "permute_times_seed": permute_times_seed,
    }

    logging.info("run_line: {0}".format(ret))
    return ret
def work_item(arg_tuple):
    """Run one G / E / GxE variance-component fit and return its results.

    ``arg_tuple`` holds, in order: ``pheno``, ``G_kernel``, ``spatial_coor``,
    ``spatial_iid``, ``alpha``, ``alpha_power``, three ``(index, count,
    seed)`` triples (jackknife, "plus" permutation, "times" permutation),
    and the shortcut switches ``just_testing``, ``do_uncorr``, ``do_gxe2``
    and ``a2``. A negative index switches the corresponding resampling step
    off.

    Steps: drop NaN phenotypes, build the spatial (E) kernel, intersect the
    inputs, optionally jackknife / permute, standardize, then fit
    (1) G alone when ``do_uncorr``, (2) the G/E mixture when ``a2`` is
    ``None``, (3) the (G+E)/GxE mixture when ``do_gxe2``. Skipped fits
    report ``np.nan``; with ``just_testing`` the fits are replaced by
    placeholder values.

    Returns a dict with the fitted quantities and the run parameters.
    """
    # The main inputs, the resampling triples, then the shortcut switches.
    (pheno, G_kernel, spatial_coor, spatial_iid, alpha, alpha_power,
     jackknife_triple, plus_triple, times_triple,
     just_testing, do_uncorr, do_gxe2, a2) = arg_tuple
    jackknife_index, jackknife_count, jackknife_seed = jackknife_triple
    permute_plus_index, permute_plus_count, permute_plus_seed = plus_triple
    permute_times_index, permute_times_count, permute_times_seed = times_triple

    # ----- Remove any missing values from pheno -----
    pheno = pheno.read()
    pheno_column = pheno.val[:, 0]
    # NaN is the only value not equal to itself, so this keeps non-NaN rows.
    pheno = pheno[pheno_column == pheno_column, :]

    # ----- Environment: turn the spatial info into a KernelData -----
    spatial_val = spatial_similarity(spatial_coor, alpha, power=alpha_power)
    E_kernel = KernelData(iid=spatial_iid, val=spatial_val)

    # ----- Intersect, jackknife / permute, then standardize -----
    from pysnptools.util import intersect_apply
    G_kernel, E_kernel, pheno = intersect_apply([G_kernel, E_kernel, pheno])

    if jackknife_index >= 0:
        assert jackknife_count <= G_kernel.iid_count, "expect the number of groups to be less than the number of iids"
        assert jackknife_index < jackknife_count, "expect the jackknife index to be less than the count"
        folds = cross_validation.KFold(
            n=G_kernel.iid_count,
            n_folds=jackknife_count,
            shuffle=True,
            random_state=jackknife_seed % 4294967295,
        )
        iid_index, _ = _nth(folds, jackknife_index)
        pheno = pheno[iid_index, :]
        G_kernel = G_kernel[iid_index]
        E_kernel = E_kernel[iid_index]

    if permute_plus_index >= 0:
        # Shuffle the values but keep the iids; shuffling both would cancel
        # out. The permutation index is folded into the seed.
        np.random.seed((permute_plus_seed + permute_plus_index) % 4294967295)
        new_index = np.arange(G_kernel.iid_count)
        np.random.shuffle(new_index)
        permuted_E = E_kernel[new_index].read()
        E_kernel = KernelData(
            iid=E_kernel.iid,
            val=permuted_E.val,
            name="permutation {0}".format(permute_plus_index),
        )

    pheno = pheno.read().standardize()  # defaults to Unit standardize
    G_kernel = G_kernel.read().standardize()  # defaults to DiagKtoN standardize
    E_kernel = E_kernel.read().standardize()  # defaults to DiagKtoN standardize

    iid_count = G_kernel.iid_count  # G_kernel is not reassigned below

    # ----- h2uncorr: best mixing weight of pure noise and G_kernel -----
    if do_uncorr:
        logging.info("Find best h2 for G_kernel")
        lmmg = LMM()
        lmmg.setK(K0=G_kernel.val)
        lmmg.setX(np.ones([iid_count, 1]))  # just a bias column
        lmmg.sety(pheno.val[:, 0])
        if just_testing:
            h2uncorr, nLLuncorr = 0, 0
        else:
            resg = lmmg.findH2()
            h2uncorr, nLLuncorr = resg["h2"], resg["nLL"]
        logging.info("just G: h2uncorr: {0}, nLLuncorr: {1}".format(h2uncorr, nLLuncorr))
    else:
        h2uncorr, nLLuncorr = np.nan, np.nan

    # ----- a2: best mixing of G_kernel and E_kernel -----
    if a2 is not None:
        h2corr, e2, nLLcorr = np.nan, np.nan, np.nan
    else:
        logging.info("Find best mixing for G_kernel and E_kernel")
        lmm1 = LMM()
        lmm1.setK(K0=G_kernel.val, K1=E_kernel.val, a2=0.5)
        lmm1.setX(np.ones([iid_count, 1]))  # just a bias column
        lmm1.sety(pheno.val[:, 0])
        if just_testing:
            h2corr, e2, a2, nLLcorr = 0, 0, .5, 0
        else:
            res1 = lmm1.findA2()
            h2, a2, nLLcorr = res1["h2"], res1["a2"], res1["nLL"]
            h2corr = h2 * (1 - a2)
            e2 = h2 * a2
        logging.info("G plus E mixture: h2corr: {0}, e2: {1}, a2: {2}, nLLcorr: {3}".format(h2corr, e2, a2, nLLcorr))

    # ----- a2_gxe2: best mixing of the G+E kernel and the GxE kernel -----
    if do_gxe2:
        # Create the G+E kernel by mixing according to a2. It needs no
        # standardization: it is a weighted combination of standardized
        # kernels.
        mixed_val = (1 - a2) * G_kernel.val + a2 * E_kernel.val
        GplusE_kernel = KernelData(
            iid=G_kernel.iid,
            val=mixed_val,
            name="{0} G + {1} E".format(1 - a2, a2),
        )

        # Create the GxE kernel and then find the best mixing of it and GplusE.
        logging.info("Find best mixing for GxE and GplusE_kernel")

        # Recall that Python '*' is just element-wise multiplication.
        product_val = G_kernel.val * E_kernel.val
        if permute_times_index >= 0:
            # Shuffle the values but keep the iids; doing both would cancel out.
            np.random.seed((permute_times_seed + permute_times_index) % 4294967295)
            new_index = np.arange(iid_count)
            np.random.shuffle(new_index)
            product_val = pstutil.sub_matrix(product_val, new_index, new_index)

        GxE_kernel = KernelData(iid=G_kernel.iid, val=product_val, name="GxE")
        GxE_kernel = GxE_kernel.standardize()

        lmm2 = LMM()
        lmm2.setK(K0=GplusE_kernel.val, K1=GxE_kernel.val, a2=0.5)
        lmm2.setX(np.ones([iid_count, 1]))  # just a bias column
        lmm2.sety(pheno.val[:, 0])
        if just_testing:
            gxe2, a2_gxe2, nLL_gxe2 = 0, .5, 0
        else:
            res2 = lmm2.findA2()
            gxe2, a2_gxe2, nLL_gxe2 = res2["h2"], res2["a2"], res2["nLL"]
            gxe2 *= a2_gxe2
        logging.info("G+E plus GxE mixture: gxe2: {0}, a2_gxe2: {1}, nLL_gxe2: {2}".format(gxe2, a2_gxe2, nLL_gxe2))
    else:
        gxe2, a2_gxe2, nLL_gxe2 = np.nan, np.nan, np.nan

    # ----- Return results -----
    ret = {
        "h2uncorr": h2uncorr,
        "nLLuncorr": nLLuncorr,
        "h2corr": h2corr,
        "e2": e2,
        "a2": a2,
        "nLLcorr": nLLcorr,
        "gxe2": gxe2,
        "a2_gxe2": a2_gxe2,
        "nLL_gxe2": nLL_gxe2,
        "alpha": alpha,
        "alpha_power": alpha_power,
        "phen": pheno.sid[0],
        "jackknife_index": jackknife_index,
        "jackknife_count": jackknife_count,
        "jackknife_seed": jackknife_seed,
        "permute_plus_index": permute_plus_index,
        "permute_plus_count": permute_plus_count,
        "permute_plus_seed": permute_plus_seed,
        "permute_times_index": permute_times_index,
        "permute_times_count": permute_times_count,
        "permute_times_seed": permute_times_seed,
    }

    logging.info("run_line: {0}".format(ret))
    return ret