Пример #1
0
 def test_kcsd1d_estimate(self, cv_params=None):
     """
     Check that a cross-validated KCSD1D estimate stays close to the
     ground-truth CSD profile.

     Parameters
     ----------
     cv_params: dict, optional
         Extra keyword arguments merged into ``self.test_params`` before
         the KCSD1D object is built.
         Default: None (no overrides).
     """
     # A None sentinel replaces the former shared mutable ``{}`` default.
     if cv_params is not None:
         self.test_params.update(cv_params)
     result = KCSD1D(self.ele_pos, self.pots, **self.test_params)
     result.cross_validate()
     vals = result.values()
     true_csd = self.csd_profile(result.estm_x, 42)
     # Relative error: ||estimate - truth|| / ||truth||.
     rms = np.linalg.norm(np.array(vals[:, 0]) - true_csd)
     rms /= np.linalg.norm(true_csd)
     self.assertLess(rms, 0.5, msg='RMS between trueCSD and estimate > 0.5')
Пример #2
0
 def test_lcurve(self):
     """
     Check that an L-curve regularised KCSD1D estimate stays close to the
     ground-truth CSD profile.
     """
     estimator = KCSD1D(self.ele_pos, self.pots, **self.test_params)
     estimator.L_curve()
     est_csd = estimator.values()
     # The potential estimate is evaluated too, but it is not compared
     # against anything below.
     estimator.values('POT')
     reference = self.csd_profile(estimator.estm_x, 42)
     # Relative error: ||estimate - truth|| / ||truth||.
     residual_norm = np.linalg.norm(np.array(est_csd[:, 0]) - reference)
     rel_err = residual_norm / np.linalg.norm(reference)
     self.assertLess(rel_err, 0.5,
                     msg='RMS between trueCSD and estimate > 0.5')
Пример #3
0
def stability_M(n_src, total_ele, ele_pos, pots, R_init=0.23):
    """
    Investigates stability of reconstruction for different number of basis
    sources

    For every entry of n_src a KCSD1D object is built and the matrix
    ``k_pot + lambd * I`` of that object is eigendecomposed.

    Parameters
    ----------
    n_src: sequence of int
        Numbers of basis sources to scan; one reconstruction per entry.
    total_ele: int
        Number of electrodes. Sizes the returned arrays, so it has to match
        the dimension of the k_pot matrix.
    ele_pos: numpy array
        Electrodes positions.
    pots: numpy array
        Values of potentials at ele_pos.
    R_init: float
        Initial value of R parameter - width of basis source
        Default: 0.23.

    Returns
    -------
    obj_all: list
        KCSD1D objects, one per entry of n_src.
    eigenvalues: numpy array
        Eigenvalues of ``k_pot + lambd * I`` in descending order,
        shape (len(n_src), total_ele).
    eigenvectors: numpy array
        Eigenvectors (as columns) ordered like the eigenvalues,
        shape (len(n_src), total_ele, total_ele).

    Raises
    ------
    LinAlgError
        If the eigendecomposition fails.
    """
    obj_all = []
    eigenvectors = np.zeros((len(n_src), total_ele, total_ele))
    eigenvalues = np.zeros((len(n_src), total_ele))
    # Loop-invariant: the column shape does not depend on the source count.
    pots = pots.reshape((len(ele_pos), 1))
    for i, value in enumerate(n_src):
        obj = KCSD1D(ele_pos,
                     pots,
                     src_type='gauss',
                     sigma=0.3,
                     h=0.25,
                     gdx=0.01,
                     n_src_init=value,
                     ext_x=0,
                     xmin=0,
                     xmax=1,
                     R_init=R_init)
        try:
            eigenvalue, eigenvector = np.linalg.eigh(
                obj.k_pot + obj.lambd * np.identity(obj.k_pot.shape[0]))
        except LinAlgError as err:
            # Keep the original failure attached as the cause.
            raise LinAlgError('EVD is failing - try moving the electrodes '
                              'slightly') from err
        # eigh returns ascending eigenvalues; store them largest first.
        idx = eigenvalue.argsort()[::-1]
        eigenvalues[i] = eigenvalue[idx]
        eigenvectors[i] = eigenvector[:, idx]
        obj_all.append(obj)
    return obj_all, eigenvalues, eigenvectors
Пример #4
0
def modified_bases(k, pots, ele_pos, n_src, title, h=0.25, sigma=0.3,
                   gdx=0.035, ext_x=0, xmin=0, xmax=1):
    '''
    Runs a cross-validated KCSD1D reconstruction with Gaussian basis
    sources, saves a summary figure and plots the spectral structure of the
    resulting object.

    Besides its arguments the function reads the names val, R, MU, csd_at,
    true_csd and SAVE_PATH, none of which are defined in its body.

    Parameters
    ----------
    k: object of the class ValidateKCSD1D
        Used for make_plot.
    pots: numpy array
        Potentials measured (calculated) on electrodes.
    ele_pos: numpy array
        Locations of electrodes.
    n_src: int
        Number of basis sources.
    title: string
        Title of the plot; also used as the file name of the saved figure.
    h: float
        Thickness of analyzed cylindrical slice.
        Default: 0.25.
    sigma: float
        Space conductance of the medium.
        Default: 0.3.
    gdx: float
        Space increments in the estimation space.
        Default: 0.035.
    ext_x: float
        Length of space extension: xmin-ext_x ... xmax+ext_x.
        Default: 0.
    xmin: float
        Boundaries for CSD estimation space.
        Default: 0.
    xmax: float
        Boundaries for CSD estimation space.
        Default: 1.

    Returns
    -------
    None
    '''
    column_pots = pots.reshape((len(ele_pos), 1))
    kcsd = KCSD1D(ele_pos,
                  column_pots,
                  src_type='gauss',
                  sigma=sigma,
                  h=h,
                  gdx=gdx,
                  n_src_init=n_src,
                  ext_x=ext_x,
                  xmin=xmin,
                  xmax=xmax)
    kcsd.cross_validate(Rs=np.arange(0.2, 0.5, 0.1))

    estimated = kcsd.values('CSD')
    reference = csd_profile(kcsd.estm_x, [R, MU])
    # NOTE(review): `val` is not a parameter of this function - presumably a
    # module-level ValidateKCSD1D instance; confirm it is not meant to be `k`.
    rms = val.calculate_rms(reference, estimated)
    caption = "Lambda: %0.2E; R: %0.2f; RMS_Error: %0.2E;" % (kcsd.lambd,
                                                           kcsd.R, rms)
    fig = k.make_plot(csd_at, true_csd, kcsd, estimated, ele_pos,
                      column_pots, caption)
    # NOTE(review): SAVE_PATH is joined with a path that already starts with
    # SAVE_PATH - confirm the resulting location is the intended one.
    target = SAVE_PATH + '/' + title + '.png'
    fig.savefig(os.path.join(SAVE_PATH, target))
    plt.close()

    spectrum = SpectralStructure(kcsd)
    eigenvectors, eigenvalues = spectrum.evd()
    plot_eigenvalues(eigenvalues, SAVE_PATH, title)
    plot_eigenvectors(eigenvectors, SAVE_PATH, title)
    plot_k_interp_cross_v(kcsd.k_interp_cross, eigenvectors, SAVE_PATH, title)
Пример #5
0
def pots_scan(n_src, ele_lims, true_csd_xlims, total_ele, ele_pos,
              R_init=0.23):
    """
    Investigates kCSD reconstructions for unitary potential on different
    electrodes

    One KCSD1D object is built per electrode; its input potentials are 1 on
    that electrode and 0 on all the others.

    Parameters
    ----------
    n_src: int
        Number of basis sources.
    ele_lims: list
        Boundaries for electrodes placement. Not used by the body.
    true_csd_xlims: list
        Boundaries for ground truth space. Not used by the body.
    total_ele: int
        Number of electrodes. Not used by the body; the count is taken from
        len(ele_pos).
    ele_pos: numpy array
        Electrodes positions.
    R_init: float
        Initial value of R parameter passed to KCSD1D.
        Default: 0.23.

    Returns
    -------
    obj_all: list
        KCSD1D objects, one per electrode.
    est_csd: list
        The 'CSD' values of every object, in the same order as obj_all.
    """
    n_ele = len(ele_pos)
    obj_all, est_csd = [], []
    for active in range(n_ele):
        # Column vector of potentials with a single unit entry.
        unit_pots = np.zeros((n_ele, 1))
        unit_pots[active, 0] = 1
        obj = KCSD1D(ele_pos, unit_pots, src_type='gauss', sigma=0.3, h=0.25,
                     gdx=0.01, n_src_init=n_src, ext_x=0, xmin=0, xmax=1,
                     R_init=R_init)
        est_csd.append(obj.values('CSD'))
        obj_all.append(obj)
    return obj_all, est_csd
Пример #6
0
def do_kcsd(ele_pos, pots, **params):
    """
    Builds a KCSD1D object, regularises it and returns the estimates.

    The names `name` and `ery` are read from outside the function: `name`
    selects the L-curve method when it equals 'lc' (cross-validation
    otherwise) and `ery` is passed on as the Rs argument.

    Parameters
    ----------
    ele_pos: numpy array
        Electrode positions; reshaped to a column.
    pots: numpy array
        Potentials; reshaped to a column with one row per electrode.
    **params
        Forwarded unchanged to KCSD1D.

    Returns
    -------
    k: KCSD1D object
    est_csd:
        Result of k.values('CSD').
    est_pot:
        Result of k.values('POT').
    """
    n_ele = len(ele_pos)
    pots_col = pots.reshape(n_ele, 1)
    pos_col = ele_pos.reshape(n_ele, 1)
    # 50 candidate lambdas, log-spaced between 1e-9 and 1e-3.
    lambdas = np.logspace(-9, -3, 50, base=10)
    k = KCSD1D(pos_col, pots_col, **params)
    regularise = k.L_curve if name == 'lc' else k.cross_validate
    regularise(Rs=ery, lambdas=lambdas)
    return k, k.values('CSD'), k.values('POT')
Пример #7
0
def do_kcsd(i, ele_pos, pots, **params):
    """
    Runs KCSD1D with cross-validation and then with the L-curve method on
    the same object, recording the chosen parameters in `LandR`.

    The names `Rs`, `lambdas` and `LandR` are read from outside the
    function. `LandR[1, :, i]` receives (lambd, R) after cross-validation
    and `LandR[0, :, i]` receives (lambd, R) after the L-curve step.

    Parameters
    ----------
    i: int
        Index along the last axis of `LandR` to write to.
    ele_pos: numpy array
        Electrode positions; reshaped to a column.
    pots: numpy array
        Potentials; reshaped to a column with one row per electrode.
    **params
        Forwarded unchanged to KCSD1D.

    Returns
    -------
    k: KCSD1D object
    est_csd:
        'CSD' values after the L-curve step.
    est_pot:
        'POT' values after the L-curve step.
    noreg_csd:
        'CSD' values read before cross_validate / L_curve were called.
    errsy:
        k.errs as read right after cross-validation.
    est_csd_cv:
        'CSD' values after cross-validation.
    """
    n_ele = len(ele_pos)
    pots_col = pots.reshape(n_ele, 1)
    pos_col = ele_pos.reshape(n_ele, 1)
    kcsd = KCSD1D(pos_col, pots_col, **params)
    csd_before = kcsd.values('CSD')

    # Pass 1: cross-validation.
    kcsd.cross_validate(Rs=Rs, lambdas=lambdas)
    cv_errors = kcsd.errs
    LandR[1, 0, i] = kcsd.lambd
    LandR[1, 1, i] = kcsd.R
    csd_cv = kcsd.values('CSD')

    # Pass 2: L-curve on the same object.
    kcsd.L_curve(lambdas=lambdas, Rs=Rs)
    LandR[0, 0, i] = kcsd.lambd
    LandR[0, 1, i] = kcsd.R
    csd_lcurve = kcsd.values('CSD')
    pot_lcurve = kcsd.values('POT')

    return kcsd, csd_lcurve, pot_lcurve, csd_before, cv_errors, csd_cv
Пример #8
0
def modified_bases(val, pots, ele_pos, n_src, title=None, h=0.25, sigma=0.3,
                   gdx=0.01, ext_x=0, xmin=0, xmax=1, R=0.2, MU=0.25,
                   method='cross-validation', Rs=None, lambdas=None):
    '''
    Builds a KCSD1D object with Gaussian basis sources and regularises it
    with the requested method.

    Parameters
    ----------
    val: object of the class ValidateKCSD1D
        Used for calculate_rms.
    pots: numpy array
        Potentials measured (calculated) on electrodes.
    ele_pos: numpy array
        Locations of electrodes.
    n_src: int
        Number of basis sources.
    title: string
        Title of the plot. Not used by the body.
        Default: None.
    h: float
        Thickness of analyzed cylindrical slice.
        Default: 0.25.
    sigma: float
        Space conductance of the medium.
        Default: 0.3.
    gdx: float
        Space increments in the estimation space.
        Default: 0.01.
    ext_x: float
        Length of space extension: xmin-ext_x ... xmax+ext_x.
        Default: 0.
    xmin: float
        Boundaries for CSD estimation space.
        Default: 0.
    xmax: float
        Boundaries for CSD estimation space.
        Default: 1.
    R: float
        Thickness of the groundtruth source.
        Default: 0.2.
    MU: float
        Central position of Gaussian source
        Default: 0.25.
    method: string
        Determines the method of regularization: 'cross-validation' or
        'L-curve'. Any other value leaves the object without a
        regularization step.
        Default: cross-validation.
    Rs: numpy 1D array
        Basis source parameter for crossvalidation.
        Default: None.
    lambdas: numpy 1D array
        Regularization parameter for crossvalidation.
        Default: None.

    Returns
    -------
    obj_m: object of the class KCSD1D
    '''
    column_pots = pots.reshape((len(ele_pos), 1))
    kcsd_kwargs = dict(src_type='gauss', sigma=sigma, h=h, gdx=gdx,
                       n_src_init=n_src, ext_x=ext_x, xmin=xmin, xmax=xmax)
    obj_m = KCSD1D(ele_pos, column_pots, **kcsd_kwargs)

    search_grid = dict(Rs=Rs, lambdas=lambdas)
    if method == 'cross-validation':
        obj_m.cross_validate(**search_grid)
    elif method == 'L-curve':
        obj_m.L_curve(**search_grid)

    # The error against the ground truth is still evaluated, but its value
    # is neither returned nor stored.
    estimated = obj_m.values('CSD')
    reference = csd_profile(obj_m.estm_x, [R, MU])
    val.calculate_rms(reference, estimated)
    return obj_m
Пример #9
0
    def make_reconstruction(self,
                            csd_profile,
                            csd_seed,
                            noise=0,
                            nr_broken_ele=None,
                            Rs=None,
                            lambdas=None,
                            method='cross-validation'):
        """
        Makes the whole kCSD reconstruction.

        Parameters
        ----------
        csd_profile: function
            function to produce csd profile
        csd_seed: int
            Seed for random generator to choose random CSD profile.
        noise: float
            Determines the level of noise in the data.
            Default: 0.
        nr_broken_ele: int
            How many electrodes are broken (excluded from analysis)
            Default: None.
        Rs: numpy 1D array
            Basis source parameter for crossvalidation.
            Default: None.
        lambdas: numpy 1D array
            Regularization parameter for crossvalidation.
            Default: None.
        method: string
            Determines the method of regularization, either
            'cross-validation' or 'L-curve'.
            Default: cross-validation.

        Returns
        -------
        rms: float
            Error of reconstruction.
        point_error: numpy array
            Error of reconstruction calculated at every point of reconstruction
            space.

        Raises
        ------
        ValueError
            If method is neither 'cross-validation' nor 'L-curve'.
        """
        # Fail fast: reject an unknown method before the electrode
        # configuration and the KCSD1D object are built.
        if method not in ('cross-validation', 'L-curve'):
            raise ValueError('Invalid value of reconstruction method, '
                             'pass either cross-validation or L-curve')

        ele_pos, pots = self.electrode_config(csd_profile, csd_seed,
                                              self.total_ele, self.ele_lims,
                                              self.h, self.sigma, noise,
                                              nr_broken_ele)

        k = KCSD1D(ele_pos,
                   pots,
                   h=self.h,
                   gdx=self.est_xres,
                   xmax=np.max(self.kcsd_xlims),
                   xmin=np.min(self.kcsd_xlims),
                   sigma=self.sigma,
                   n_src_init=self.n_src_init)
        if method == 'cross-validation':
            k.cross_validate(Rs=Rs, lambdas=lambdas)
        else:
            k.L_curve(Rs=Rs, lambdas=lambdas)

        # Only the first column of the estimate is compared to the truth.
        est_csd = k.values('CSD')[:, 0]
        test_csd = csd_profile(k.estm_x, csd_seed)
        rms = self.calculate_rms(test_csd, est_csd)
        point_error = self.calculate_point_error(test_csd, est_csd)
        return rms, point_error
Пример #10
0
        Kt = gpcsd_model.temporal_cov_list[0].compute_Kt(
        ) + gpcsd_model.temporal_cov_list[1].compute_Kt()
        Ks = gpcsd_model.spatial_cov.compKphi_1d(
            gpcsd_model.R['value']) + 1e-8 * np.eye(nx)
        Qs, Qt, Dvec = comp_eig_D(Ks, Kt, gpcsd_model.sig2n['value'])
        cov_probe['Qs'] = Qs
        cov_probe['Qt'] = Qt
        cov_probe['Dvec'] = Dvec

        # Compute empirical mean of CSD estimated by GPCSD as evoked response
        gpcsd_model.predict(z, trial_pred_t)
        evoked_probe['gpcsd'] = np.mean(gpcsd_model.csd_pred, 2)

        # kCSD estimation of evoked response for comparison
        kcsd_evoked_model = KCSD1D(x,
                                   lfp_trial_pred_evoked,
                                   gdx=1.,
                                   h=gpcsd_model.R['value'])
        kcsd_evoked_model.cross_validate(Rs=np.linspace(100, 800, 15),
                                         lambdas=np.logspace(1,
                                                             -15,
                                                             25,
                                                             base=10.))
        evoked_probe['kcsd'] = kcsd_evoked_model.values()

        with open('%s/results/gpcsd_evoked_%s.pkl' % (root_path, probe_name),
                  'wb') as f:
            pickle.dump(evoked_probe, f)
        evoked[probe_name] = evoked_probe

        with open('%s/results/gpcsd_cov_%s.pkl' % (root_path, probe_name),
                  'wb') as f:
Пример #11
0
    gpcsd_model = GPCSD1D(lfp[:, :, 50:], x, t)
    gpcsd_model.R['value'] = R_true
    gpcsd_model.sig2n['value'] = sig2n_true
    gpcsd_model.spatial_cov.params['ell']['value'] = ellSE_true
    gpcsd_model.temporal_cov_list[0].params['ell']['value'] = elltSE_true
    gpcsd_model.temporal_cov_list[0].params['sigma2']['value'] = sig2tSE_true
    gpcsd_model.temporal_cov_list[1].params['ell']['value'] = elltM_true
    gpcsd_model.temporal_cov_list[1].params['sigma2']['value'] = sig2tM_true

# Print the GPCSD model, then run its prediction on (xshort, t).
print(gpcsd_model)
gpcsd_model.predict(xshort, t)

# %% kCSD estimation
# Use the first five trials, concatenated into nx rows, for estimating the
# parameters (for computational reasons).
kcsd_model = KCSD1D(x,
                    lfp[:, :, :5].reshape((nx, -1)),
                    gdx=deltax / 20,
                    h=R_true)
# Cross-validate over 15 evenly spaced R candidates between 100 and 1000.
kcsd_model.cross_validate(Rs=np.linspace(100, 1000, 15))

# Keep R and lambda as stored on the model after cross-validation so the
# per-trial models below can be built with them.
kcsd_R = kcsd_model.R
kcsd_lambda = kcsd_model.lambd
kcsd_values = []
# Predict on test set: one KCSD1D object per slice lfp[:, :, 50 + i], built
# with the R and lambda from above; no cross-validation is run here.
for i in range(ntrials):
    kcsd_model_tmp = KCSD1D(x,
                            lfp[:, :, 50 + i].squeeze(),
                            gdx=deltax / 20,
                            h=R_true,
                            R_init=kcsd_R,
                            lambd=kcsd_lambda)
    kcsd_values_tmp = kcsd_model_tmp.values()
Пример #12
0
# %% GPCSD fitting and predictions
# For every key of `lfp`, fit the matching GPCSD model with 10 restarts and
# then run its prediction on (z, t).
for k in lfp.keys():
    print('\nStarting %s' % k)
    gpcsd[k].fit(n_restarts=10)
    gpcsd[k].predict(z, t)

# %% Print GPCSD fitting results
for k in lfp.keys():
    print(gpcsd[k])

# %% kCSD estimation
from kcsd import KCSD1D  # https://github.com/Neuroinflab/kCSD-python/releases/tag/v2.0
# time.process_time() measures CPU time of the current process, so the
# figure printed below is CPU time rather than wall-clock time.
start_t = time.process_time()
est_kcsd = {}
for k in lfp.keys():
    kcsd_tmp = KCSD1D(x, lfp[k], gdx=deltaz, h=R_true)
    # Cross-validate over 15 R candidates in [100, 800] and 25 log-spaced
    # lambda candidates from 1e1 down to 1e-15.
    kcsd_tmp.cross_validate(Rs=np.linspace(100, 800, 15),
                            lambdas=np.logspace(1, -15, 25, base=10.))
    est_kcsd[k] = kcsd_tmp.values()
end_t = time.process_time()
# Report the average time per dataset (total divided by the dataset count).
print('kCSD took %0.2f s (per dataset, with cross-validation)' %
      ((end_t - start_t) / len(est_kcsd)))

# %% Visualize results
vmlfp = np.amax(np.abs(lfp['noiseless']))
vmcsd = np.amax(np.abs(normalize(csd_true)))

plt.rcParams.update({'font.size': 12})

f = plt.figure(figsize=(16, 8))
ax = plt.subplot(251)