import cupy
import scipy.spatial.distance


def save_evaluations(user, population, optim):
    # define pearson correlation coefficient vector
    pearsons = cupy.zeros(population.shape[0], dtype=cupy.float64)
    # define mse vector
    mse = cupy.zeros(population.shape[0], dtype=cupy.float64)
    # define common gene counter vector
    common_genes = cupy.zeros(population.shape[0], dtype=cupy.int64)
    # define cosine similarity vector
    cosines = cupy.zeros(population.shape[0], dtype=cupy.float64)
    # compute vectors
    for i in range(population.shape[0]):
        pearsons[i] = cupy.corrcoef(population[i], optim)[0, 1]
        mse[i] = (cupy.square(optim - population[i])).mean(axis=None)
        common_genes[i] = evaluate_chromosome(population[i], optim)
        # scipy works on NumPy arrays, so move the CuPy vectors to the host first
        cosines[i] = 1 - scipy.spatial.distance.cosine(cupy.asnumpy(optim),
                                                       cupy.asnumpy(population[i]))
    # save vectors to txt files using as prefix the user id
    cupy.savetxt("user_" + str(user) + "-pearsons.txt",
                 pearsons,
                 delimiter="\t")
    cupy.savetxt("user_" + str(user) + "-mse.txt", mse, delimiter="\t")
    cupy.savetxt("user_" + str(user) + "-common_genes.txt",
                 common_genes.astype(int),
                 fmt='%i',
                 delimiter="\t")
    cupy.savetxt("user_" + str(user) + "-cosines.txt", cosines, delimiter="\t")
    return
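
A minimal usage sketch (not part of the original snippet): it assumes a hypothetical evaluate_chromosome helper that just counts matching positions, and calls save_evaluations on a small random population.

import cupy

def evaluate_chromosome(chromosome, optim):
    # hypothetical helper: number of positions where the chromosome matches the optimum
    return int(cupy.sum(chromosome == optim))

# toy data: 10 chromosomes of length 50 with binary genes
population = cupy.random.randint(0, 2, size=(10, 50)).astype(cupy.float64)
optim = cupy.random.randint(0, 2, size=50).astype(cupy.float64)

save_evaluations(user=1, population=population, optim=optim)
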
Example n. 2
def eval_emo_lex(derived_emo_lex, emo_lex, trans, induct_emos_file, induct_emos_eval_file, emotion):
    print("Number of derived emotion ratings:", len(derived_emo_lex))
    derived_emos = []
    real_emos = []
    words = []
    trans = {word_src: tgt_words for word_src, tgt_words in trans}
    for word, emo in derived_emo_lex.items():
        translations = ",".join([t[0] for t in trans[word]])
        induct_emos_file.write(f"{word}\t{translations}\t{emotion}\t{emo}\n")
        real_emo = emo_lex.get(word)
        if real_emo is not None:  # keep ratings of 0 as well; only skip missing words
            induct_emos_eval_file.write(f"{word}\t{translations}\t{emotion}\t{emo}\t{real_emo}\n")
            derived_emos.append(emo)
            real_emos.append(real_emo)
            words.append(word)
    
    print("Coverage in test set:", len(derived_emos) / len(derived_emo_lex))

    derived_emos = np.array(derived_emos, dtype=float)
    real_emos = np.array(real_emos, dtype=float)
    corr_coeff = np.corrcoef(derived_emos, real_emos, rowvar=False)
    top_words = np.argsort(-derived_emos)[:10]
    print(derived_emos[top_words])
    top_words = [words[int(idx)] for idx in top_words]
    print(top_words)
    corr_coeff = np.around(corr_coeff[0, 1], 3)
    print("Correlation:", corr_coeff)
    return [corr_coeff, len(derived_emo_lex), derived_emos.shape[0]]
Example n. 3
    def _calcGeneSNPcorr(self, cr, gene, REF, useAll=False):

        if self._joint and self._MAP is not None:
            G = self._GENEID[gene]
            P = SortedSet(REF[str(cr)][1].irange(G[1] - self._window,
                                                 G[2] + self._window))

            if gene in self._MAP:
                P.update(
                    list(REF[str(cr)][0].getSNPsPos(
                        list(self._MAP[gene].keys()))))
                #P = list(set(P))

        elif self._MAP is None:
            G = self._GENEID[gene]

            P = REF[str(cr)][1].irange(G[1] - self._window,
                                       G[2] + self._window)
        else:
            if gene in self._MAP:
                P = set(REF[str(cr)][0].getSNPsPos(list(
                    self._MAP[gene].keys())))
                useAll = True
            else:
                P = []

        DATA = REF[str(cr)][0].get(list(P))

        filtered = {}

        #use = []
        #RID = []

        # Sort out
        for D in DATA:
            # Select
            if (D[0] in self._GWAS or useAll) and (D[1] > self._MAF) and (
                    D[0] not in filtered or D[1] < filtered[D[0]][0]):
                filtered[D[0]] = [D[1], D[2]]
                #use.append(D[2])
                #RID.append(s)

        # Calc corr
        RID = list(filtered.keys())
        use = np.array([filtered[rid][1] for rid in RID])

        if len(use) > 1:
            if self._useGPU:
                C = cp.asnumpy(cp.corrcoef(cp.asarray(use)))
            else:
                C = np.corrcoef(use)
        else:
            C = np.ones((1, 1))

        return C, np.array(RID)
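
The GPU/CPU fallback at the end of this method is a recurring pattern in these snippets; a minimal standalone sketch of the same idea (function and variable names here are illustrative, not from the source):

import numpy as np

try:
    import cupy as cp
except ImportError:
    cp = None

def corrcoef_auto(rows, use_gpu=True):
    # correlation matrix across rows, computed on the GPU when CuPy is
    # available, otherwise with NumPy; a single row degenerates to [[1.]]
    rows = np.atleast_2d(np.asarray(rows))
    if rows.shape[0] < 2:
        return np.ones((1, 1))
    if use_gpu and cp is not None:
        return cp.asnumpy(cp.corrcoef(cp.asarray(rows)))
    return np.corrcoef(rows)
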
Example n. 4
    def __lioness_loop(self):
        """
        Description:
            Initialize instance of Lioness class and load data.

        Outputs:
            self.total_lioness_network: An edge-by-sample matrix containing sample-specific networks.
        """
        for i in self.indexes:
            print("Running LIONESS for sample %d:" % (i+1))
            idx = [x for x in range(self.n_conditions) if x != i]  # all samples except i
            with Timer("Computing coexpression network:"):
                if self.computing=='gpu':
                    import cupy as cp
                    correlation_matrix = cp.corrcoef(self.expression_matrix[:, idx])
                    if cp.isnan(correlation_matrix).any():
                        cp.fill_diagonal(correlation_matrix, 1)
                        correlation_matrix = cp.nan_to_num(correlation_matrix)
                    correlation_matrix=cp.asnumpy(correlation_matrix)
                else:
                    correlation_matrix = np.corrcoef(self.expression_matrix[:, idx])
                    if np.isnan(correlation_matrix).any():
                        np.fill_diagonal(correlation_matrix, 1)
                        correlation_matrix = np.nan_to_num(correlation_matrix)

            with Timer("Normalizing networks:"):
                correlation_matrix_orig = correlation_matrix # save matrix before normalization
                correlation_matrix = self._normalize_network(correlation_matrix)

            with Timer("Inferring LIONESS network:"):
                if self.motif_matrix is not None:
                    del correlation_matrix_orig
                    subset_panda_network = self.panda_loop(correlation_matrix, np.copy(self.motif_matrix), np.copy(self.ppi_matrix),self.computing)
                else:
                    del correlation_matrix
                    subset_panda_network = correlation_matrix_orig

            lioness_network = self.n_conditions * (self.network - subset_panda_network) + subset_panda_network

            with Timer("Saving LIONESS network %d to %s using %s format:" % (i+1, self.save_dir, self.save_fmt)):
                path = os.path.join(self.save_dir, "lioness.%d.%s" % (i+1, self.save_fmt))
                if self.save_fmt == 'txt':
                    np.savetxt(path, lioness_network)
                elif self.save_fmt == 'npy':
                    np.save(path, lioness_network)
                elif self.save_fmt == 'mat':
                    from scipy.io import savemat
                    savemat(path, {'PredNet': lioness_network})
                else:
                    print("Unknown format %s! Use npy format instead." % self.save_fmt)
                    np.save(path, lioness_network)
            if i == 0:
                # flatten the transposed network so each sample becomes one column
                self.total_lioness_network = np.transpose(lioness_network).flatten()
            else:
                self.total_lioness_network = np.column_stack(
                    (self.total_lioness_network,
                     np.transpose(lioness_network).flatten()))

        return self.total_lioness_network
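
The single-sample extraction above follows the LIONESS interpolation formula; a compact sketch of just that step, where plain Pearson networks stand in for the aggregate (PANDA) network used by the class (names and toy sizes are illustrative):

import numpy as np

def lioness_edge_matrix(expression_matrix, sample_idx):
    # network over all samples vs. network with one sample left out;
    # the LIONESS estimate for that sample is n * (all - loo) + loo
    n = expression_matrix.shape[1]
    net_all = np.corrcoef(expression_matrix)
    keep = [j for j in range(n) if j != sample_idx]
    net_loo = np.corrcoef(expression_matrix[:, keep])
    return n * (net_all - net_loo) + net_loo

# toy example: 5 genes x 8 samples
expr = np.random.rand(5, 8)
sample_net = lioness_edge_matrix(expr, sample_idx=0)
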
Example n. 5
    def estim_distance_matrix(self,
                              *,
                              from_sample=0,
                              to_sample=20000 * 60 * 5,
                              dist_coef=143,
                              dist_power=1 / 3):
        if self.estimated_coordinates is None:
            filtdata = self.filtered_data(from_sample=from_sample,
                                          to_sample=to_sample)
            corrs = cp.corrcoef(filtdata)
            corrs[corrs < 0] = 1e-10  # clamp negative correlations to a small positive floor
            distances = dist_coef / corrs**dist_power - dist_coef
            distances = distances.get()
            self.distance_matrix = distances
            return distances
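
A standalone sketch of the correlation-to-distance transform used above, with the negative-correlation clamp written as an assignment (the constants are copied from the snippet, the data is synthetic):

import cupy as cp

def corr_to_distance(data, dist_coef=143, dist_power=1 / 3):
    # Pearson correlation between channels, clamped to a small positive
    # floor, then mapped to a distance that shrinks as correlation grows
    corrs = cp.corrcoef(cp.asarray(data))
    corrs[corrs < 0] = 1e-10
    distances = dist_coef / corrs**dist_power - dist_coef
    return distances.get()  # back to a NumPy array

# toy data: 4 channels x 1000 samples
dist = corr_to_distance(cp.random.standard_normal((4, 1000)))
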
Example n. 6
import numpy
import scipy.sparse
import cupy

from sklearn.datasets import load_digits

from numpy.testing import assert_almost_equal
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_almost_equal

digits_data = load_digits()
X_digits = digits_data.data

norm = lambda x: numpy.sqrt((x * x).sum(axis=1)).reshape(x.shape[0], 1)
cosine = lambda x: numpy.dot(x, x.T) / (norm(x).dot(norm(x).T))

X_digits_sparse = scipy.sparse.csr_matrix(cosine(X_digits))

X_digits_corr_cupy = cupy.corrcoef(cupy.array(X_digits), rowvar=True)**2
X_digits_cosine_cupy = cupy.array(cosine(X_digits))

digits_corr_ranking = [
    424, 1647, 396, 339, 1030, 331, 983, 1075, 1482, 1539, 1282, 493, 885, 823,
    1051, 236, 537, 1161, 345, 1788, 1432, 1634, 1718, 1676, 146, 1286, 655,
    1292, 556, 533, 1545, 520, 1711, 1428, 620, 1276, 305, 438, 1026, 183, 2,
    384, 1012, 798, 213, 1291, 162, 1206, 227, 1655, 233, 1508, 410, 1295,
    1312, 1350, 514, 938, 579, 1066, 82, 164, 948, 1588, 1294, 1682, 943, 517,
    959, 1429, 762, 898, 1556, 881, 1470, 1549, 1325, 1568, 937, 347, 1364,
    126, 732, 1168, 241, 573, 731, 815, 864, 1639, 1570, 411, 1086, 696, 870,
    1156, 353, 160, 1381, 326
]

digits_corr_gains = [
    736.794, 114.2782, 65.4154, 61.3037, 54.5428, 38.7506, 34.097, 32.6649,
    # ... (remaining gain values truncated in the source listing)
]

Example n. 7
    def reconstruction_gpu(self,
                           signals,
                           stop_criteria,
                           alpha_1,
                           alpha_2,
                           alpha_3,
                           alpha_4,
                           max_iterations,
                           iterations=False,
                           guess=None,
                           verbose=False):
        """Apply the minimum fisher reconstruction algorithm for a given set of measurements from tomography.

        input:
            signals: array
                Array of signals from each sensor ordered like in "projections".
            stop_criteria: float
                Average difference between iterations required to admit convergence, expressed as a fraction between 0 and 1.
            alpha_1, alpha_2, alpha_3, alpha_4: float
                Regularization weights for the horizontal derivative, vertical derivative, outside norm and inside norm, respectively.
            max_iterations: int
                Maximum number of iterations before algorithm admits non-convergence
            iterations: boolean, optional
                If set to `True`, returns every iteration step as a reconstruction. Defaults to False.
            guess: ndarray, optional
                Initial guess for the reconstructed profile.
            verbose: boolean, optional
                Print inner convergence messages. Defaults to `False`

        output:
            inner_loop_output: InnerLoopOutput class instance
                Object holding the data from the mfi inner loop.
        """

        # Aliasing for cleaner code --------------------------------------------------
        # Dh = self._cp.array(self._Dh, dtype=cp.float32)
        # Dht = cp.transpose(Dh)
        # Dv = cp.array(self._Dv, dtype=cp.float32)
        # Dvt = cp.transpose(Dv)
        # Pt = cp.array(self._Pt, dtype=cp.float32)
        n_rows = self._n_rows
        n_cols = self._n_cols

        Dh = self._gpu_Dh
        Dht = self._gpu_Dht
        Dv = self._gpu_Dv
        Dvt = self._gpu_Dvt
        Pt = self._gpu_Pt
        PtP = self._gpu_PtP
        ItIi = self._gpu_ItIi
        ItIo = self._gpu_ItIo

        # Instantiate f vector of signals --------------
        f = cp.array(signals, dtype=cp.float32)

        # -----------------------------  FIRST ITERATION  -------------------------------------------------------------

        # First guess to g is uniform plasma distribution --------------------------
        if guess is None:
            g_old = cp.ones(n_rows * n_cols, dtype=cp.float32)
        else:
            g_old = cp.array(guess, dtype=cp.float32)
            g_old[g_old < 1e-20] = 1e-20

        # List of emissivities -----------------------------------------------------
        g_list = []

        # Weight matrix ------------------------------------------------------------
        W = cp.diag(1.0 / cp.abs(g_old))
        # cp.asnumpy(W)

        # Fisher information (weighted derivatives) --------------------------------
        DtWDh = cp.dot(Dht, cp.dot(W, Dh))
        # cp.asnumpy(DtWDh)
        DtWDv = cp.dot(Dvt, cp.dot(W, Dv))
        # cp.asnumpy(DtWDh)

        # Inversion and calculation of vector g, storage of first guess ------------

        inv = cp.linalg.inv(alpha_1 * DtWDh + alpha_2 * DtWDv + PtP +
                            alpha_3 * ItIo + alpha_4 * ItIi)
        # cp.asnumpy(inv)

        M = cp.dot(inv, Pt)
        # cp.asnumpy(M)

        g_old = cp.dot(M, f)
        # cp.asnumpy(g_old)

        # first_g = cp.array(g_old)
        if iterations:
            g_list.append(cp.asnumpy(g_old.reshape((n_rows, n_cols))))

        # Iterative process --------------------------------------------------------
        for i in range(2, max_iterations + 1):

            g_old[g_old < 1e-20] = 1e-20

            W = cp.diag(1.0 / cp.abs(g_old))
            # cp.asnumpy(W)

            DtWDh = cp.dot(Dht, cp.dot(W, Dh))
            # cp.asnumpy(DtWDh)
            DtWDv = cp.dot(Dvt, cp.dot(W, Dv))
            # cp.asnumpy(DtWDv)

            inv = cp.linalg.inv(alpha_1 * DtWDh + alpha_2 * DtWDv + PtP +
                                alpha_3 * ItIo + alpha_4 * ItIi)
            # cp.asnumpy(inv)

            M = cp.dot(inv, Pt)
            # cp.asnumpy(M)

            g_new = cp.dot(M, f)
            # cp.asnumpy(g_new)

            # plt.figure()
            # plt.imshow(g_new.reshape((n_rows, n_cols)))

            # error = np.sum(np.abs((g_new[g_new > 1e-5] - g_old[g_new > 1e-5]) / g_new[g_new > 1e-5])) / len(g_new > 1e-5)
            # error = cp.sum(cp.abs(g_new - g_old)) / cp.sum(cp.abs(first_g))
            # error = cp.sum(cp.abs(g_new - g_old)**2) / cp.max(g_new)**2
            cov = 1. - 0.5 * ((cp.corrcoef(g_new, g_old)) + cp.corrcoef(
                g_new.reshape((n_rows, n_cols)).T.flatten(),
                g_old.reshape((n_rows, n_cols)).T.flatten()))
            error = cov[0, 1]
            # cp.asnumpy(error)

            if verbose:
                print("Iteration %d changed by %.4f%%" % (i, error * 100.))

            g_old = cp.array(g_new)  # Explicitly copy because python will not
            # cp.asnumpy(g_old)

            if iterations:
                g_list.append(cp.asnumpy(g_new.reshape((n_rows, n_cols))))

            if ((i >= 5) and (error > 0.90)) or cp.isnan(error):
                print("WARNING: Minimum Fisher is not converging, aborting...")
                convergence_flag = False
                break

            elif error < stop_criteria:
                print("Minimum Fisher converged after %d iterations." % i)
                convergence_flag = True
                break

            elif i == max_iterations:  # Break just before the `for loop` does
                print(
                    "WARNING: Minimum Fisher did not converge after %d iterations."
                    % i)
                convergence_flag = False
                break

        if not iterations:
            g_list.append(cp.asnumpy(g_new.reshape((n_rows, n_cols))))

        # return np.array(g_list)
        return InnerLoopOutput(iterations=g_list,
                               n_rows=n_rows,
                               n_cols=n_cols,
                               convergence_flag=convergence_flag)
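
The stopping rule above uses 1 minus the correlation between consecutive iterates as its change measure; a reduced sketch of that idea on plain vectors, keeping only the row-order correlation (independent of the reconstruction class, names are illustrative):

import cupy as cp

def iterate_change(g_new, g_old):
    # 1 - Pearson correlation between consecutive iterates:
    # close to 0 once the solution stops changing, larger while it still moves
    return float(1.0 - cp.corrcoef(g_new, g_old)[0, 1])

g_prev = cp.random.random(100)
g_curr = g_prev + 0.01 * cp.random.random(100)
print(iterate_change(g_curr, g_prev))  # small value, close to 0
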
Example n. 8
    def _calcMultiGeneSNPcorr(self, cr, genes, REF, wAlleles=True):

        filtered = {}

        use = []
        RID = []
        pos = []

        for gene in genes:
            DATA = []
            if self._joint and self._MAP is not None:
                G = self._GENEID[gene]
                P = SortedSet(REF[str(cr)][1].irange(G[1] - self._window,
                                                     G[2] + self._window))

                if gene in self._MAP:
                    P.update(
                        list(REF[str(cr)][0].getSNPsPos(self._MAP[gene][0])))
                    #P = list(set(P))

            elif self._MAP is None:
                G = self._GENEID[gene]

                P = REF[str(cr)][1].irange(G[1] - self._window,
                                           G[2] + self._window)
            else:
                if gene in self._MAP:
                    P = set(REF[str(cr)][0].getSNPsPos(self._MAP[gene][0]))
                else:
                    P = []

            DATA = REF[str(cr)][0].get(list(P))

            # Sort out
            for D in DATA:
                # Select
                if D[0] in self._GWAS and D[1] > self._MAF and (
                        D[0] not in filtered or filtered[D[0]][0] < D[1]) and (
                            not wAlleles or
                            (self._GWAS_alleles[D[0]][0] == D[3]
                             and self._GWAS_alleles[D[0]][1] == D[4])):
                    filtered[D[0]] = [D[1], D[2]]
                    #use.append(D[2])
                    #RID.append(s)

            pos.append(len(filtered))

        # Calc corr
        RID = list(filtered.keys())
        use = np.array([filtered[rid][1] for rid in RID])

        if len(use) > 1:
            if self._useGPU:
                C = cp.asnumpy(cp.corrcoef(cp.asarray(use)))
            else:
                C = np.corrcoef(use)
        else:
            C = np.ones((1, 1))

        return C, np.array(RID), pos
Example n. 9
def splitAllClusters(ctx, flag):
    # I call this algorithm "bimodal pursuit"
    # split clusters if they have bimodal projections
    # the strategy is to maximize a bimodality score and find a single vector projection
    # that maximizes it. If the distribution along that maximal projection crosses a
    # bimodality threshold, then the cluster is split along that direction
    # it only uses the PC features for each spike, stored in ir.cProjPC

    params = ctx.params
    probe = ctx.probe
    ir = ctx.intermediate
    Nchan = ctx.probe.Nchan

    wPCA = cp.asarray(
        ir.wPCA
    )  # use PCA projections to reconstruct templates when we do splits
    assert wPCA.shape[1] == 3

    # Take intermediate arrays from context.
    st3 = cp.asnumpy(ir.st3_m)
    cProjPC = ir.cProjPC
    dWU = ir.dWU

    # For the following arrays that will be overwritten by this function, try to get
    # it from a previous call to this function (as it is called twice), otherwise
    # get it from before (without the _s suffix).
    W = ir.get('W_s', ir.W)
    simScore = ir.get('simScore_s', ir.simScore)
    iNeigh = ir.get('iNeigh_s', ir.iNeigh)
    iNeighPC = ir.get('iNeighPC_s', ir.iNeighPC)

    # this is the threshold for splits, and is one of the main parameters users can change
    ccsplit = params.AUCsplit

    NchanNear = min(Nchan, 32)
    Nnearest = min(Nchan, 32)
    sigmaMask = params.sigmaMask

    ik = -1
    Nfilt = W.shape[1]
    nsplits = 0

    # determine what channels each template lives on
    iC, mask, C2C = getClosestChannels(probe, sigmaMask, NchanNear)

    # the waveforms must be aligned to this sample
    nt0min = params.nt0min
    # find the peak abs channel for each template
    iW = np.argmax(np.abs((dWU[nt0min - 1, :, :])), axis=0)

    # keep track of original cluster for each cluster. starts with all clusters being their
    # own origin.
    isplit = np.arange(Nfilt)
    dt = 1. / 1000
    nccg = 0

    while ik < Nfilt:
        if ik % 100 == 0:
            # periodically write updates
            logger.info(
                f'Found {nsplits} splits, checked {ik}/{Nfilt} clusters, nccg {nccg}'
            )
        ik += 1

        isp = (st3[:, 1] == ik)  # get all spikes from this cluster
        nSpikes = isp.sum()
        logger.debug(f"Splitting template {ik}/{Nfilt} with {nSpikes} spikes.")
        free_gpu_memory()

        if nSpikes < 300:
            # TODO: move_to_config
            # do not split if fewer than 300 spikes (we cannot estimate
            # cross-correlograms accurately)
            continue

        ss = st3[isp, 0] / params.fs  # convert to seconds

        clp0 = cProjPC[isp, :, :]  # get the PC projections for these spikes
        clp0 = cp.asarray(clp0, dtype=cp.float32)  # upload to the GPU
        clp0 = clp0.reshape((clp0.shape[0], -1), order='F')
        clp = clp0 - mean(clp0, axis=0)  # mean center them

        isp = np.nonzero(isp)[0]

        # subtract a running average, because the projections are NOT drift corrected
        clp = clp - my_conv2(clp, 250, 0)

        # now use two different ways to initialize the bimodal direction
        # the main script calls this function twice, and does both initializations

        if flag:
            u, s, v = svdecon(clp.T)
            #    u, v = -u, -v  # change sign for consistency with MATLAB
            w = u[:, 0]  # initialize with the top PC
        else:
            w = mean(clp0, axis=0
                     )  # initialize with the mean of NOT drift-corrected trace
            w = w / cp.sum(w**2)**0.5  # unit-normalize

        # initial projections of waveform PCs onto 1D vector
        x = cp.dot(clp, w)
        s1 = var(
            x[x > mean(x)])  # initialize estimates of variance for the first
        s2 = var(x[
            x < mean(x)])  # and second gaussian in the mixture of 1D gaussians

        mu1 = mean(x[x > mean(x)])  # initialize the means as well
        mu2 = mean(x[x < mean(x)])
        # and the probability that a spike is assigned to the first Gaussian
        p = mean(x > mean(x))

        # initialize matrix of log probabilities that each spike is assigned to the first
        # or second cluster
        logp = cp.zeros((nSpikes, 2), order='F')

        # do 50 pursuit iterations

        logP = cp.zeros(50)  # used to monitor the cost function

        # TODO: move_to_config - maybe...
        for k in range(50):
            # for each spike, estimate its probability to come from either Gaussian cluster
            logp[:, 0] = -1. / 2 * log(s1) - ((x - mu1)**2) / (2 * s1) + log(p)
            logp[:, 1] = -1. / 2 * log(s2) - (
                (x - mu2)**2) / (2 * s2) + log(1 - p)

            lMax = logp.max(axis=1)
            logp = logp - lMax[:, cp.
                               newaxis]  # subtract the max for floating point accuracy
            rs = cp.exp(logp)  # exponentiate the probabilities

            pval = cp.log(cp.sum(
                rs, axis=1)) + lMax  # get the normalizer and add back the max
            logP[k] = mean(
                pval)  # this is the cost function: we can monitor its increase

            rs = rs / cp.sum(
                rs,
                axis=1)[:,
                        cp.newaxis]  # normalize so that probabilities sum to 1

            p = mean(rs[:, 0])  # mean probability to be assigned to Gaussian 1
            # new estimate of mean of cluster 1 (weighted by "responsibilities")
            mu1 = cp.dot(rs[:, 0], x) / cp.sum(rs[:, 0])
            # new estimate of mean of cluster 2 (weighted by "responsibilities")
            mu2 = cp.dot(rs[:, 1], x) / cp.sum(rs[:, 1])

            s1 = cp.dot(rs[:, 0], (x - mu1)**2) / cp.sum(
                rs[:, 0])  # new estimates of variances
            s2 = cp.dot(rs[:, 1], (x - mu2)**2) / cp.sum(rs[:, 1])

            if (k >= 10) and (k % 2 == 0):
                # starting at iteration 10, we start re-estimating the pursuit direction
                # that is, given the Gaussian cluster assignments, and the mean and variances,
                # we re-estimate w
                # these equations follow from the model
                StS = cp.matmul(
                    clp.T,
                    clp *
                    (rs[:, 0] / s1 + rs[:, 1] / s2)[:, cp.newaxis]) / nSpikes
                StMu = cp.dot(
                    clp.T, rs[:, 0] * mu1 / s1 + rs[:, 1] * mu2 / s2) / nSpikes

                # this is the new estimate of the best pursuit direction
                w = cp.linalg.solve(StS.T, StMu)
                w = w / cp.sum(w**2)**0.5  # which we unit normalize
                x = cp.dot(clp, w)

        # these spikes are assigned to cluster 1
        ilow = rs[:, 0] > rs[:, 1]
        # the mean probability of spikes assigned to cluster 1
        plow = mean(rs[:, 0][ilow])
        phigh = mean(rs[:, 1][~ilow])  # same for cluster 2
        # the smallest cluster has this proportion of all spikes
        nremove = min(mean(ilow), mean(~ilow))

        # did this split fix the autocorrelograms?
        # compute the cross-correlogram between spikes in the putative new clusters
        ilow_cpu = cp.asnumpy(ilow)
        K, Qi, Q00, Q01, rir = ccg(ss[ilow_cpu], ss[~ilow_cpu], 500, dt)
        Q12 = (Qi / max(Q00, Q01)).min()  # refractoriness metric 1
        R = rir.min()  # refractoriness metric 2

        # if the CCG has a dip, don't do the split.
        # These thresholds are consistent with the ones from merges.
        # TODO: move_to_config (or at least a single constant so they are the same as the merges)
        if (Q12 < 0.25) and (R < 0.05):  # if both metrics are below threshold.
            nccg += 1  # keep track of how many splits were voided by the CCG criterion
            continue

        # now decide if the split would result in waveforms that are too similar
        # the reconstructed mean waveforms for putative cluster 1
        # c1 = cp.matmul(wPCA, cp.reshape((mean(clp0[ilow, :], 0), 3, -1), order='F'))
        c1 = cp.matmul(wPCA,
                       mean(clp0[ilow, :], 0).reshape((3, -1), order='F'))
        # the reconstructed mean waveforms for putative cluster 2
        # c2 = cp.matmul(wPCA, cp.reshape((mean(clp0[~ilow, :], 0), 3, -1), order='F'))
        c2 = cp.matmul(wPCA,
                       mean(clp0[~ilow, :], 0).reshape((3, -1), order='F'))

        cc = cp.corrcoef(c1.ravel(),
                         c2.ravel())  # correlation of mean waveforms
        n1 = sqrt(cp.sum(c1**2))  # the amplitude estimate 1
        n2 = sqrt(cp.sum(c2**2))  # the amplitude estimate 2

        r0 = 2 * abs((n1 - n2) / (n1 + n2))

        # if the templates are correlated, and their amplitudes are similar, stop the split!!!

        # TODO: move_to_config
        if (cc[0, 1] > 0.9) and (r0 < 0.2):
            continue

        # final criteria to continue with the split: the split piece must be more than 5% of all
        # spikes, must contain more than 300 spikes, and the confidence for assigning spikes to
        # both clusters must exceed a preset criterion ccsplit
        # TODO: move_to_config
        if (nremove > 0.05) and (min(plow, phigh) > ccsplit) and (min(
                cp.sum(ilow), cp.sum(~ilow)) > 300):
            # one cluster stays, one goes
            Nfilt += 1

            # the templates for the splits have been estimated from PC coefficients

            # (DEV_NOTES) the code below involves multiple CuPy arrays changing shape to accommodate
            # the extra cluster; this could potentially be done more efficiently

            dWU = cp.concatenate(
                (cp.asarray(dWU), cp.zeros((*dWU.shape[:-1], 1), order='F')),
                axis=2)
            dWU[:, iC[:, iW[ik]], Nfilt - 1] = c2
            dWU[:, iC[:, iW[ik]], ik] = c1

            # the temporal components are therefore just the PC waveforms
            W = cp.asarray(W)
            W = cp.concatenate((W, cp.transpose(cp.atleast_3d(wPCA),
                                                (0, 2, 1))),
                               axis=1)
            assert W.shape[1] == Nfilt

            # copy the best channel from the original template
            iW = cp.asarray(iW)
            iW = cp.pad(iW, (0, (Nfilt - len(iW))), mode='constant')
            iW[Nfilt - 1] = iW[ik]
            assert iW.shape[0] == Nfilt

            # copy the provenance index to keep track of splits
            isplit = cp.asarray(isplit)
            isplit = cp.pad(isplit, (0, (Nfilt - len(isplit))),
                            mode='constant')
            isplit[Nfilt - 1] = isplit[ik]
            assert isplit.shape[0] == Nfilt

            st3[isp[ilow_cpu],
                1] = Nfilt - 1  # overwrite spike indices with the new index

            # copy similarity scores from the original
            simScore = cp.asarray(simScore)
            simScore = cp.pad(simScore, (0, (Nfilt - simScore.shape[0])),
                              mode='constant')
            simScore[:, Nfilt - 1] = simScore[:, ik]
            simScore[Nfilt - 1, :] = simScore[ik, :]
            # copy similarity scores from the original
            simScore[ik,
                     Nfilt - 1] = 1  # set the similarity with original to 1
            simScore[Nfilt - 1,
                     ik] = 1  # set the similarity with original to 1
            assert simScore.shape == (Nfilt, Nfilt)

            # copy neighbor template list from the original
            iNeigh = cp.asarray(iNeigh)
            iNeigh = cp.pad(iNeigh, ((0, 0), (0, (Nfilt - iNeigh.shape[1]))),
                            mode='constant')
            iNeigh[:, Nfilt - 1] = iNeigh[:, ik]
            assert iNeigh.shape[1] == Nfilt

            # copy neighbor channel list from the original
            iNeighPC = cp.asarray(iNeighPC)
            iNeighPC = cp.pad(iNeighPC,
                              ((0, 0), (0, (Nfilt - iNeighPC.shape[1]))),
                              mode='constant')
            iNeighPC[:, Nfilt - 1] = iNeighPC[:, ik]
            assert iNeighPC.shape[1] == Nfilt

            # try this cluster again
            # the cluster piece that stays at this index needs to be tested for splits again
            # before proceeding
            ik -= 1
            # the piece that became a new cluster will be tested again when we get to the end
            # of the list
            nsplits += 1  # keep track of how many splits we did
    #         pbar.update(ik)
    # pbar.close()

    logger.info(f'Finished splitting. Found {nsplits} splits, checked '
                f'{ik}/{Nfilt} clusters, nccg {nccg}')

    Nfilt = W.shape[1]  # new number of templates
    Nrank = 3
    Nchan = probe.Nchan
    Params = cp.array(
        [
            0, Nfilt, 0, 0, W.shape[0], Nnearest, Nrank, 0, 0, Nchan,
            NchanNear, nt0min, 0
        ],
        dtype=cp.float64)  # make a new Params to pass on parameters to CUDA

    # we need to re-estimate the spatial profiles

    # we get the time upsampling kernels again
    Ka, Kb = getKernels(params)
    # we run SVD
    W, U, mu = mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb)

    # we re-compute similarity scores between templates
    WtW, iList = getMeWtW(W.astype(cp.float32), U.astype(cp.float32), Nnearest)
    # ir.iList = iList  # over-write the list of nearest templates

    isplit = simScore == 1  # overwrite the similarity scores of clusters with same parent
    simScore = WtW.max(axis=2)
    simScore[isplit] = 1  # 1 means they come from the same parent

    iNeigh = iList[:, :Nfilt]  # get the new neighbor templates
    iNeighPC = iC[:, iW[:Nfilt]]  # get the new neighbor channels

    # for Phy, we need to pad the spikes with zeros so the spikes are aligned to the center of
    # the window
    Wphy = cp.concatenate((cp.zeros((1 + nt0min, Nfilt, Nrank), order='F'), W),
                          axis=0)

    # ir.isplit = isplit  # keep track of origins for each cluster

    return Bunch(
        st3_s=st3,
        W_s=W,
        U_s=U,
        mu_s=mu,
        simScore_s=simScore,
        iNeigh_s=iNeigh,
        iNeighPC_s=iNeighPC,
        Wphy=Wphy,
        iList=iList,
        isplit=isplit,
    )
Example n. 10

# benchmark method
num_times = int(snakemake.config['benchmark']['num_times'])

try:
    times = timeit.repeat("cp.asnumpy(cp.corrcoef(cp.array(mat)))",
                          globals=globals(),
                          number=1,
                          repeat=num_times)
except Exception:
    # if the graphics card runs out of memory, return an empty result
    times = [np.nan]

# save timings
with open(snakemake.output['timings'], 'w') as fp:
    entry = [
        'Pearson', 'Python', 'cupy.corrcoef', snakemake.wildcards['nrows'],
        str(min(times))
    ]
    fp.write(", ".join(entry) + "\n")

# store correlation matrix result for comparison
try:
    cor_mat = cp.asnumpy(cp.corrcoef(cp.array(mat)))
    cor_mat[np.tril_indices_from(cor_mat)] = np.nan
except Exception:
    nrows = int(snakemake.wildcards['nrows'])
    cor_mat = np.repeat(np.nan, nrows**2).reshape((nrows, nrows))

np.save(snakemake.output['cor_mat'], cor_mat)
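
A minimal, snakemake-free version of the same benchmark (matrix size and repeat count are made up here), timing cupy.corrcoef against numpy.corrcoef on the same random matrix:

import timeit

import numpy as np
import cupy as cp

mat = np.random.rand(1000, 200)

gpu_time = min(timeit.repeat("cp.asnumpy(cp.corrcoef(cp.array(mat)))",
                             globals=globals(), number=1, repeat=3))
cpu_time = min(timeit.repeat("np.corrcoef(mat)",
                             globals=globals(), number=1, repeat=3))

print(f"cupy.corrcoef:  {gpu_time:.4f} s")
print(f"numpy.corrcoef: {cpu_time:.4f} s")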