def test_gamma_map_vol_sphere():
    """Gamma MAP with a sphere forward and volumic source space."""
    evoked = read_evokeds(fname_evoked, condition=0, baseline=(None, 0),
                          proj=False)
    evoked.resample(50, npad=100)
    evoked.crop(tmin=0.1, tmax=0.16)  # crop to window around peak

    cov = read_cov(fname_cov)
    cov = regularize(cov, evoked.info, rank=None)

    info = evoked.info
    sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=0.080)
    src = mne.setup_volume_source_space(subject=None, pos=30., mri=None,
                                        sphere=(0.0, 0.0, 0.0, 80.0),
                                        bem=None, mindist=5.0,
                                        exclude=2.0)
    fwd = mne.make_forward_solution(info, trans=None, src=src, bem=sphere,
                                    eeg=False, meg=True)

    alpha = 0.5
    pytest.raises(ValueError, gamma_map, evoked, fwd, cov, alpha,
                  loose=0, return_residual=False)

    pytest.raises(ValueError, gamma_map, evoked, fwd, cov, alpha,
                  loose=0.2, return_residual=False)

    stc = gamma_map(evoked, fwd, cov, alpha, tol=1e-4,
                    xyz_same_gamma=False, update_mode=2,
                    return_residual=False)

    assert_array_almost_equal(stc.times, evoked.times, 5)

    # Compare orientation obtained using fit_dipole and gamma_map
    # for a simulated evoked containing a single dipole
    stc = mne.VolSourceEstimate(50e-9 * np.random.RandomState(42).randn(1, 4),
                                vertices=stc.vertices[:1],
                                tmin=stc.tmin,
                                tstep=stc.tstep)
    evoked_dip = mne.simulation.simulate_evoked(fwd, stc, info, cov, nave=1e9,
                                                use_cps=True)

    dip_gmap = gamma_map(evoked_dip, fwd, cov, 0.1, return_as_dipoles=True)

    amp_max = [np.max(d.amplitude) for d in dip_gmap]
    dip_gmap = dip_gmap[np.argmax(amp_max)]
    assert (dip_gmap[0].pos[0] in src[0]['rr'][stc.vertices])

    dip_fit = mne.fit_dipole(evoked_dip, cov, sphere)[0]
    assert (np.abs(np.dot(dip_fit.ori[0], dip_gmap.ori[0])) > 0.99)
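# Hedged sketch (not part of the test above): the same gamma_map call pattern on
# user data, assuming `evoked`, a volume-source-space `fwd`, and `cov` are already
# loaded. Volume/discrete source spaces only support free orientations, which is
# why loose=0 and loose=0.2 raise ValueError in the test.
from mne.inverse_sparse import gamma_map

stc = gamma_map(evoked, fwd, cov, alpha=0.5, loose=1.0,
                xyz_same_gamma=False, return_residual=False)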
Example #2
def add_volume_stcs(stc1, stc2):
    """Adds two SourceEstimates together, allowing for different vertices."""
    vertices = np.union1d(stc1.vertices, stc2.vertices)

    assert stc1.data.shape[1] == stc2.data.shape[1]
    assert stc1.tmin == stc2.tmin
    assert stc1.tstep == stc2.tstep

    data = np.zeros((len(vertices), stc1.data.shape[1]))
    for i, vert in enumerate(vertices):
        if vert in stc1.vertices[0]:
            data[[i]] += stc1.data[stc1.vertices[0] == vert]
        if vert in stc2.vertices[0]:
            data[[i]] += stc2.data[stc2.vertices[0] == vert]

    return mne.VolSourceEstimate(data, [vertices],
                                 tmin=stc1.tmin,
                                 tstep=stc1.tstep)
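# Hedged usage sketch for add_volume_stcs with two toy estimates on partially
# overlapping vertices. Depending on the MNE version, VolSourceEstimate expects
# `vertices` either as a single array or, as assumed by the function above, as a
# list containing one array.
import numpy as np
import mne

stc_a = mne.VolSourceEstimate(np.ones((2, 3)), [np.array([0, 2])], tmin=0., tstep=0.01)
stc_b = mne.VolSourceEstimate(np.ones((2, 3)), [np.array([2, 5])], tmin=0., tstep=0.01)
stc_sum = add_volume_stcs(stc_a, stc_b)
print(stc_sum.vertices)    # union of the two vertex sets: 0, 2, 5
print(stc_sum.data[:, 0])  # vertex 2 gets the summed amplitude: [1. 2. 1.]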
Example #3
def compute_fwds_stc(position, perts, sphere):
    pos = position.copy()
    pos['rr'] = mne.transforms.apply_trans(head_mri_t, position['rr'])  # invert back to mri
    pos['nn'] = mne.transforms.apply_trans(head_mri_t, position['nn'])
    src = mne.setup_volume_source_space(subject=subject, pos=pos, mri=None,
                                        sphere=(0, 0, 0, 90), bem=None,
                                        surface=None, mindist=1.0, exclude=0.0,
                                        subjects_dir=None, volume_label=None,
                                        add_interpolator=True, verbose=None)
    fwd_pert = make_pert_forward_solution(raw_fname, trans=trans, src=src, bem=sphere, perts=perts,
                                          meg=True, eeg=False, mindist=1.0, n_jobs=1)
    fwd = mne.make_forward_solution(raw_fname, trans=trans, src=src, bem=sphere,
                                    meg=True, eeg=False, mindist=1.0, n_jobs=1)
    fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                             use_cps=True)
    fwd_pert_fixed = mne.convert_forward_solution(fwd_pert, surf_ori=True, force_fixed=True,
                                                  use_cps=True)

    amplitude = 1e-5
    stc = mne.VolSourceEstimate(amplitude * np.eye(1), [[0]], tmin=0., tstep=1)
    return fwd_fixed, fwd_pert_fixed, stc
Example #4
def fit_dips(min_rad, max_rad, nn, sphere, perts, sourcenorm):
    testsources = dict(rr=[], nn=[])
    nsources = max_rad - min_rad + 1
    vertices = np.zeros((nsources, 1))
    for i in range(min_rad, max_rad + 1):
        ex, ey, ez = sourcenorm[0], sourcenorm[1], sourcenorm[2]
        source = [.001*i*ex, .001*i*ey, .001*i*ez]
        normal = [nn[0], nn[1], nn[2]]
        testsources['rr'].append(source)
        testsources['nn'].append(normal)
        vertices[i - min_rad] = i

    pos = dict(rr=[0], nn=[0])
    pos['rr'] = mne.transforms.apply_trans(head_mri_t, testsources['rr'])  # invert back to mri
    pos['nn'] = mne.transforms.apply_trans(head_mri_t, testsources['nn'])
    src = mne.setup_volume_source_space(subject=subject, pos=pos, mri=None,
                                        sphere=(0, 0, 0, 90), bem=None,
                                        surface=None, mindist=1.0, exclude=0.0,
                                        subjects_dir=None, volume_label=None,
                                        add_interpolator=True, verbose=None)
    fwd_pert = make_pert_forward_solution(raw_fname, trans=trans, src=src, bem=sphere, perts=perts,
                                          meg=True, eeg=False, mindist=1.0, n_jobs=1)
    fwd = mne.make_forward_solution(raw_fname, trans=trans, src=src, bem=sphere,
                                    meg=True, eeg=False, mindist=1.0, n_jobs=1)
    fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                             use_cps=True)
    fwd_pert_fixed = mne.convert_forward_solution(fwd_pert, surf_ori=True, force_fixed=True,
                                                  use_cps=True)

    amplitude = 1e-5
    source = np.eye(nsources) * amplitude

    stc = mne.VolSourceEstimate(source, vertices, tmin=0., tstep=1)
    evoked = mne.simulation.simulate_evoked(fwd_fixed, stc, info, cov, use_cps=True,
                                            iir_filter=None)
    evoked_pert = mne.simulation.simulate_evoked(fwd_pert_fixed, stc, info, cov, use_cps=True,
                                                 iir_filter=None)
    dip_fit_long = mne.fit_dipole(evoked, cov_fname, sphere, trans)[0]
    dip_fit_pert = mne.fit_dipole(evoked_pert, cov_fname, sphere, trans)[0]
    return dip_fit_long, dip_fit_pert, testsources

    ###############################################################################
    # Create dist stc from simulated data
    ###############################################################################

    vert_sel = sel['vertex'].to_numpy()
    data_dist_sel = sel['dist'].to_numpy()
    data_eval_sel = sel['eval'].to_numpy()
    data_ori_sel = sel['ori_error'].to_numpy()

    data_dist = np.zeros(shape=(vertno.shape[0], 1))
    data_dist[vert_sel, 0] = data_dist_sel + 0.001
    vstc_dist = mne.VolSourceEstimate(data=data_dist,
                                      vertices=vertno,
                                      tmin=0,
                                      tstep=1 / info['sfreq'],
                                      subject='sample')

    data_eval = np.zeros(shape=(vertno.shape[0], 1))
    data_eval[vert_sel, 0] = data_eval_sel + 0.001
    vstc_eval = mne.VolSourceEstimate(data=data_eval,
                                      vertices=vertno,
                                      tmin=0,
                                      tstep=1 / info['sfreq'],
                                      subject='sample')

    data_ori = np.zeros(shape=(vertno.shape[0], 1))
    data_ori[vert_sel, 0] = data_ori_sel + 0.001
    vstc_ori = mne.VolSourceEstimate(data=data_ori,
                                     vertices=vertno,
                                     tmin=0,
                                     tstep=1 / info['sfreq'],
                                     subject='sample')
Example #6
def test_mxne_vol_sphere():
    """(TF-)MxNE with a sphere forward and volumic source space"""
    evoked = read_evokeds(fname_data, condition=0, baseline=(None, 0))
    evoked.crop(tmin=-0.05, tmax=0.2)
    cov = read_cov(fname_cov)

    evoked_l21 = evoked.copy()
    evoked_l21.crop(tmin=0.081, tmax=0.1)

    info = evoked.info
    sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=0.080)
    src = mne.setup_volume_source_space(subject=None,
                                        pos=15.,
                                        mri=None,
                                        sphere=(0.0, 0.0, 0.0, 80.0),
                                        bem=None,
                                        mindist=5.0,
                                        exclude=2.0)
    fwd = mne.make_forward_solution(info,
                                    trans=None,
                                    src=src,
                                    bem=sphere,
                                    eeg=False,
                                    meg=True)

    alpha = 80.
    assert_raises(ValueError,
                  mixed_norm,
                  evoked,
                  fwd,
                  cov,
                  alpha,
                  loose=0.0,
                  return_residual=False,
                  maxit=3,
                  tol=1e-8,
                  active_set_size=10)

    assert_raises(ValueError,
                  mixed_norm,
                  evoked,
                  fwd,
                  cov,
                  alpha,
                  loose=0.2,
                  return_residual=False,
                  maxit=3,
                  tol=1e-8,
                  active_set_size=10)

    # irMxNE tests
    stc = mixed_norm(evoked_l21,
                     fwd,
                     cov,
                     alpha,
                     n_mxne_iter=1,
                     maxit=30,
                     tol=1e-8,
                     active_set_size=10)
    assert_true(isinstance(stc, VolSourceEstimate))
    assert_array_almost_equal(stc.times, evoked_l21.times, 5)

    # Compare orientation obtained using fit_dipole and mixed_norm
    # for a simulated evoked containing a single dipole
    stc = mne.VolSourceEstimate(50e-9 * np.random.RandomState(42).randn(1, 4),
                                vertices=stc.vertices[:1],
                                tmin=stc.tmin,
                                tstep=stc.tstep)
    evoked_dip = mne.simulation.simulate_evoked(fwd,
                                                stc,
                                                info,
                                                cov,
                                                nave=1e9,
                                                use_cps=True)

    dip_mxne = mixed_norm(evoked_dip,
                          fwd,
                          cov,
                          alpha=80,
                          n_mxne_iter=1,
                          maxit=30,
                          tol=1e-8,
                          active_set_size=10,
                          return_as_dipoles=True)

    amp_max = [np.max(d.amplitude) for d in dip_mxne]
    dip_mxne = dip_mxne[np.argmax(amp_max)]
    assert_true(dip_mxne.pos[0] in src[0]['rr'][stc.vertices])

    dip_fit = mne.fit_dipole(evoked_dip, cov, sphere)[0]
    assert_true(np.abs(np.dot(dip_fit.ori[0], dip_mxne.ori[0])) > 0.99)

    # Do with TF-MxNE for test memory savings
    alpha = 60.  # overall regularization parameter
    l1_ratio = 0.01  # temporal regularization proportion

    stc, _ = tf_mixed_norm(evoked,
                           fwd,
                           cov,
                           maxit=3,
                           tol=1e-4,
                           tstep=16,
                           wsize=32,
                           window=0.1,
                           alpha=alpha,
                           l1_ratio=l1_ratio,
                           return_residual=True)
    assert_true(isinstance(stc, VolSourceEstimate))
    assert_array_almost_equal(stc.times, evoked.times, 5)
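# Hedged sketch: inspecting the TF-MxNE fit above. `wsize` and `tstep` are the
# STFT window length and window step in samples, and `residual` (returned when
# return_residual=True) is an Evoked holding whatever the sparse solution did not
# explain. Assumes `evoked`, `fwd`, and `cov` as defined in the test above.
from mne.inverse_sparse import tf_mixed_norm

stc, residual = tf_mixed_norm(evoked, fwd, cov, alpha=60., l1_ratio=0.01,
                              wsize=32, tstep=16, maxit=3, tol=1e-4,
                              return_residual=True)
print(stc.data.shape)  # (n_active_sources, n_times)
residual.plot()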
Example #7
del fwd

##############################################################################
# Compute label time series and do envelope correlation
# -----------------------------------------------------

epochs.apply_hilbert()  # faster to do in sensor space
stcs = apply_lcmv_epochs(epochs, filters, return_generator=True)
corr = envelope_correlation(stcs, verbose=True)

##############################################################################
# Compute the degree and plot it
# ------------------------------

degree = mne.connectivity.degree(corr, 0.15)
stc = mne.VolSourceEstimate(degree, src[0]['vertno'], 0, 1, 'bst_resting')
brain = stc.plot(src,
                 clim=dict(kind='percent', lims=[75, 85, 95]),
                 colormap='gnuplot',
                 subjects_dir=subjects_dir,
                 mode='glass_brain')
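# Hedged illustration of the proportional threshold used above, on a toy symmetric
# matrix: degree() keeps roughly the strongest 15% of connections and returns the
# connection count per node. In recent releases this function lives in the
# separate mne_connectivity package rather than mne.connectivity.
import numpy as np

rng = np.random.RandomState(0)
toy_corr = rng.rand(10, 10)
toy_corr = (toy_corr + toy_corr.T) / 2.  # symmetric, like envelope correlations
print(mne.connectivity.degree(toy_corr, threshold_prop=0.15))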

##############################################################################
# References
# ----------
# .. [1] Hipp JF, Hawellek DJ, Corbetta M, Siegel M, Engel AK (2012)
#        Large-scale cortical correlation structure of spontaneous
#        oscillatory activity. Nature Neuroscience 15:884–890
# .. [2] Khan S et al. (2018). Maturation trajectories of cortical
#        resting-state networks depend on the mediating frequency band.
#        Neuroimage 174:57–68
def add_source_to_raw(raw, fwd_disc_true, signal_vertex, signal_freq,
                      trial_length, n_trials, source_type):
    """
    Add a new simulated dipole source to an existing raw. Operates on a copy of
    the raw.

    Parameters
    ----------
    raw : instance of Raw
        The raw data to add a new source to.
    fwd_disc_true : instance of mne.Forward
        The forward operator for the discrete source space created with
        the true transformation file.
    signal_vertex : int
        The vertex where signal dipole is placed.
    signal_freq : float
        The frequency of the signal.
    trial_length : float
        Length of a single trial in samples.
    n_trials : int
        Number of trials to create.
    source_type : 'chirp' | 'random'
        Type of source signal to add.

    Returns
    -------
    raw : instance of Raw
        The summation of the original raw and the simulated source.
    stc_signal : instance of SourceEstimate
        Source time courses of the new signal.
    """
    sfreq = raw.info['sfreq']
    trial_length = int((config.tmax - config.tmin) * sfreq)
    times = np.arange(trial_length) / sfreq + config.tmin

    src = fwd_disc_true['src']
    signal_vert = src[0]['vertno'][signal_vertex]
    data = np.zeros(len(times))
    signal_part = times >= 0

    if source_type == 'chirp':
        data[signal_part] += generate_signal(times[signal_part],
                                             signal_freq,
                                             phase=0.5)
    elif source_type == 'random':
        data[signal_part] += generate_random(times[signal_part])

    vertices = np.array([signal_vert])
    stc_signal = mne.VolSourceEstimate(data=data[np.newaxis, :],
                                       vertices=vertices,
                                       tmin=times[0],
                                       tstep=np.diff(times[:2])[0],
                                       subject='sample')
    raw_signal = simulate_raw_mne(raw.info,
                                  stc_signal,
                                  trans=None,
                                  src=None,
                                  bem=None,
                                  forward=fwd_disc_true)
    raw_signal = mne.concatenate_raws(
        [raw_signal.copy() for _ in range(n_trials)])

    raw = raw.copy()
    raw_picks = mne.pick_types(raw.info, meg=True, eeg=False)
    raw._data[raw_picks] += raw_signal._data[raw_picks]

    return raw, stc_signal
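# Hedged usage sketch: add a 10 Hz chirp source at the 100th source-space vertex of
# an existing recording. `raw`, `fwd_disc_true`, and the module-level `config` and
# `generate_signal` helpers are assumed to exist as in the project this function
# comes from.
raw_with_source, stc_signal = add_source_to_raw(
    raw, fwd_disc_true, signal_vertex=100, signal_freq=10.,
    trial_length=1., n_trials=5, source_type='chirp')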
def simulate_raw(info,
                 fwd_disc_true,
                 signal_vertex,
                 signal_freq,
                 n_trials,
                 noise_multiplier,
                 random_state,
                 n_noise_dipoles,
                 er_raw,
                 fn_stc_signal=None,
                 fn_simulated_raw=None,
                 fn_report_h5=None):
    """
    Simulate raw time courses for a single signal dipole with the
    frequency given by signal_freq. Noise dipoles are placed
    randomly in the whole cortex.

    Parameters
    ----------
    info : instance of Info | instance of Raw
        The channel information to use for simulation.
    fwd_disc_true : instance of mne.Forward
        The forward operator for the discrete source space created with
        the true transformation file.
    signal_vertex : int
        The vertex where signal dipole is placed.
    signal_freq : float
        The frequency of the signal.
    n_trials : int
        Number of trials to create.
    noise_multiplier : float
        Multiplier for the noise dipoles. For noise_multiplier equal to one
        the signal and noise dipoles have the same magnitude.
    random_state : None | int | instance of RandomState
        If random_state is an int, it will be used as a seed for RandomState.
        If None, the seed will be obtained from the operating system (see
        RandomState for details). Default is None.
    n_noise_dipoles : int
        The number of noise dipoles to place within the volume.
    er_raw : instance of Raw
        Empty room measurement to be used as sensor noise.
    fn_stc_signal : None | string
        Path where the signal source time courses are to be saved. If None the file is not saved.
    fn_simulated_raw : None | string
        Path where the raw data is to be saved. If None the file is not saved.
    fn_report_h5 : None | string
        Path where the .h5 file for the report is to be saved.

    Returns
    -------
    raw : instance of Raw
        Simulated raw file.
    stc_signal : instance of SourceEstimate
        Source time courses of the signal.
    """

    sfreq = info['sfreq']
    trial_length = int((config.tmax - config.tmin) * sfreq)
    times = np.arange(trial_length) / sfreq + config.tmin

    ###############################################################################
    # Simulate a single signal dipole source as signal
    ###############################################################################

    # TODO: I think a discrete source space was used because mne.simulate_raw did not take volume source spaces -> test
    src = fwd_disc_true['src']
    signal_vert = src[0]['vertno'][signal_vertex]
    data = np.asarray([generate_signal(times, freq=signal_freq)])
    vertices = np.array([signal_vert])
    stc_signal = mne.VolSourceEstimate(data=data,
                                       vertices=[vertices],
                                       tmin=times[0],
                                       tstep=np.diff(times[:2])[0],
                                       subject='sample')
    if fn_stc_signal is not None:
        set_directory(op.dirname(fn_stc_signal))
        stc_signal.save(fn_stc_signal)

    ###############################################################################
    # Create trials of simulated data
    ###############################################################################

    # select n_noise_dipoles entries from rr and their corresponding entries from nn
    raw_list = []

    for i in range(n_trials):
        # Simulate random noise dipoles
        stc_noise = simulate_sparse_stc(src,
                                        n_noise_dipoles,
                                        times,
                                        data_fun=generate_random,
                                        random_state=random_state,
                                        labels=None)

        # Project to sensor space
        stc = add_volume_stcs(stc_signal, noise_multiplier * stc_noise)

        raw = simulate_raw_mne(info,
                               stc,
                               trans=None,
                               src=None,
                               bem=None,
                               forward=fwd_disc_true)

        raw_list.append(raw)
        print('%02d/%02d' % (i + 1, n_trials))

    raw = mne.concatenate_raws(raw_list)

    # Use empty room noise as sensor noise
    raw_picks = mne.pick_types(raw.info, meg=True, eeg=False)
    er_raw_picks = mne.pick_types(er_raw.info, meg=True, eeg=False)
    raw._data[raw_picks] += er_raw._data[er_raw_picks, :len(raw.times)]

    ###############################################################################
    # Save everything
    ###############################################################################

    if fn_simulated_raw is not None:
        set_directory(op.dirname(fn_simulated_raw))
        raw.save(fn_simulated_raw, overwrite=True)

    # Plot the simulated raw data in the report
    if fn_report_h5 is not None:
        from matplotlib import pyplot as plt
        set_directory(op.dirname(fn_report_h5))
        fn_report_html = fn_report_h5.rsplit('.h5')[0] + '.html'

        now = datetime.now()
        with mne.open_report(fn_report_h5) as report:
            fig = plt.figure()
            plt.plot(times, generate_signal(times, freq=10))
            plt.xlabel('Time (s)')

            ax = fig.axes[0]
            add_text_next_to_xlabel(fig, ax,
                                    now.strftime('%m/%d/%Y, %H:%M:%S'))

            report.add_figs_to_section(fig,
                                       now.strftime('Signal time course'),
                                       section='Sensor-level',
                                       replace=True)

            fig = raw.plot()

            # axis 1 contains the xlabel
            ax = fig.axes[1]
            add_text_next_to_xlabel(fig, ax,
                                    now.strftime('%m/%d/%Y, %H:%M:%S'))

            report.add_figs_to_section(fig,
                                       now.strftime('Simulated raw'),
                                       section='Sensor-level',
                                       replace=True)
            report.save(fn_report_html, overwrite=True, open_browser=False)

    raw._annotations = mne.annotations.Annotations([], [], [])
    return raw, stc_signal
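# Hedged usage sketch for simulate_raw as defined above. `info`, `fwd_disc_true`,
# and `er_raw` (an empty-room recording) are assumed to be loaded, and
# `config.tmin` / `config.tmax` define the trial window.
raw_sim, stc_signal = simulate_raw(info, fwd_disc_true, signal_vertex=100,
                                   signal_freq=10., n_trials=5,
                                   noise_multiplier=1., random_state=42,
                                   n_noise_dipoles=50, er_raw=er_raw)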
Example #10
def run():
    t0 = time.time()
    parser = get_optparser(__file__)
    parser.add_option("--raw",
                      dest="raw_in",
                      help="Input raw FIF file",
                      metavar="FILE")
    parser.add_option("--pos",
                      dest="pos",
                      default=None,
                      help="Position definition text file. Can be 'constant' "
                      "to hold the head position fixed",
                      metavar="FILE")
    parser.add_option("--dipoles",
                      dest="dipoles",
                      default=None,
                      help="Dipole definition file",
                      metavar="FILE")
    parser.add_option("--cov",
                      dest="cov",
                      help="Covariance to use for noise generation. Can be "
                      "'simple' to use a diagonal covariance, or 'off' to "
                      "omit noise",
                      metavar="FILE",
                      default='simple')
    parser.add_option("--duration",
                      dest="duration",
                      default=None,
                      help="Duration of each epoch (sec). If omitted, the last"
                      " time point in the dipole definition file plus 200 ms "
                      "will be used",
                      type="float")
    parser.add_option("-j",
                      "--jobs",
                      dest="n_jobs",
                      help="Number of jobs to"
                      " run in parallel",
                      type="int",
                      default=1)
    parser.add_option("--out",
                      dest="raw_out",
                      help="Output raw filename",
                      metavar="FILE")
    parser.add_option("--plot-dipoles",
                      dest="plot_dipoles",
                      help="Plot "
                      "input dipole positions",
                      action="store_true")
    parser.add_option("--plot-raw",
                      dest="plot_raw",
                      help="Plot the resulting "
                      "raw traces",
                      action="store_true")
    parser.add_option("--plot-evoked",
                      dest="plot_evoked",
                      help="Plot evoked "
                      "data",
                      action="store_true")
    parser.add_option("-p",
                      "--plot",
                      dest="plot",
                      help="Plot dipoles, raw, "
                      "and evoked",
                      action="store_true")
    parser.add_option("--overwrite",
                      dest="overwrite",
                      help="Overwrite the"
                      "output file if it exists",
                      action="store_true")
    options, args = parser.parse_args()

    raw_in = options.raw_in
    pos = options.pos
    raw_out = options.raw_out
    dipoles = options.dipoles
    n_jobs = options.n_jobs
    plot = options.plot
    plot_dipoles = options.plot_dipoles or plot
    plot_raw = options.plot_raw or plot
    plot_evoked = options.plot_evoked or plot
    overwrite = options.overwrite
    duration = options.duration
    cov = options.cov

    # check parameters
    if not (raw_out or plot_raw or plot_evoked):
        raise ValueError('data must either be saved (--out) or '
                         'plotted (--plot-raw or --plot-evoked)')
    if raw_out and op.isfile(raw_out) and not overwrite:
        raise ValueError('output file exists, use --overwrite (%s)' % raw_out)

    if raw_in is None or pos is None or dipoles is None:
        parser.print_help()
        sys.exit(1)

    s = 'Simulate raw data with head movements'
    print('\n%s\n%s\n%s\n' % ('-' * len(s), s, '-' * len(s)))

    # setup the simulation

    with printer('Reading dipole definitions'):
        if not op.isfile(dipoles):
            raise IOError('dipole file not found:\n%s' % dipoles)
        dipoles = np.loadtxt(dipoles, skiprows=1, dtype=float)
        n_dipoles = dipoles.shape[0]
        if dipoles.shape[1] != 8:
            raise ValueError('dipoles must have 8 columns')
        rr = dipoles[:, :3] * 1e-3
        nn = dipoles[:, 3:6]
        t = dipoles[:, 6:8]
        duration = t.max() + 0.2 if duration is None else duration
        if (t[:, 0] > t[:, 1]).any():
            raise ValueError('found tmin > tmax in dipole file')
        if (t < 0).any():
            raise ValueError('found t < 0 in dipole file')
        if (t > duration).any():
            raise ValueError('found t > duration in dipole file')
        amp = np.sqrt(np.sum(nn * nn, axis=1)) * 1e-9
        mne.surface._normalize_vectors(nn)
        nn[(nn == 0).all(axis=1)] = (1, 0, 0)
        src = mne.SourceSpaces([
            dict(rr=rr,
                 nn=nn,
                 inuse=np.ones(n_dipoles, int),
                 coord_frame=FIFF.FIFFV_COORD_HEAD)
        ])
        for key in ['pinfo', 'nuse_tri', 'use_tris', 'patch_inds']:
            src[0][key] = None
        trans = {
            'from': FIFF.FIFFV_COORD_HEAD,
            'to': FIFF.FIFFV_COORD_MRI,
            'trans': np.eye(4)
        }
        if (amp > 100e-9).any():
            print('')
            warnings.warn('Largest dipole amplitude %0.1f > 100 nA' %
                          (amp.max() * 1e9))

    if pos == 'constant':
        print('Holding head position constant')
        pos = None
    else:
        with printer('Loading head positions'):
            pos = mne.get_chpi_positions(pos)

    with printer('Loading raw data file'):
        with warnings.catch_warnings(record=True):
            raw = mne.io.Raw(raw_in,
                             preload=False,
                             allow_maxshield=True,
                             verbose=False)

    if cov == 'simple':
        print('Using diagonal covariance for brain noise')
    elif cov == 'off':
        print('Omitting brain noise in the simulation')
        cov = None
    else:
        with printer('Loading covariance file for brain noise'):
            cov = mne.read_cov(cov)

    with printer('Setting up spherical model'):
        bem = mne.bem.make_sphere_model('auto',
                                        'auto',
                                        raw.info,
                                        verbose=False)
        # check that our sources are reasonable
        rad = bem['layers'][0]['rad']
        r0 = bem['r0']
        outside = np.sqrt(np.sum((rr - r0)**2, axis=1)) >= rad
        n_outside = outside.sum()
        if n_outside > 0:
            print('')
            raise ValueError(
                '%s dipole%s outside the spherical model, are your positions '
                'in mm?' % (n_outside, 's were' if n_outside != 1 else ' was'))

    with printer('Constructing source estimate'):
        tmids = t.mean(axis=1)
        t = np.round(t * raw.info['sfreq']).astype(int)
        t[:, 1] += 1  # make it inclusive
        n_samp = int(np.ceil(duration * raw.info['sfreq']))
        data = np.zeros((n_dipoles, n_samp))
        for di, (t_, amp_) in enumerate(zip(t, amp)):
            data[di, t_[0]:t_[1]] = amp_ * np.hanning(t_[1] - t_[0])
        stc = mne.VolSourceEstimate(data, np.arange(n_dipoles), 0,
                                    1. / raw.info['sfreq'])

    # do the simulation
    print('')
    raw_mv = simulate_raw(raw,
                          stc,
                          trans,
                          src,
                          bem,
                          cov=cov,
                          head_pos=pos,
                          chpi=True,
                          n_jobs=n_jobs,
                          verbose=True)
    print('')

    if raw_out:
        with printer('Saving data'):
            raw_mv.save(raw_out, overwrite=overwrite)

    # plot results -- must be *after* save because we low-pass filter
    if plot_dipoles:
        with printer('Plotting dipoles'):
            fig, axs = plt.subplots(1, 3, figsize=(10, 3), facecolor='w')
            fig.canvas.set_window_title('Dipoles')
            meg_info = mne.pick_info(
                raw.info, mne.pick_types(raw.info, meg=True, eeg=False))
            helmet_rr = [
                ch['coil_trans'][:3, 3].copy() for ch in meg_info['chs']
            ]
            helmet_nn = np.zeros_like(helmet_rr)
            helmet_nn[:, 2] = 1.
            surf = dict(rr=helmet_rr,
                        nn=helmet_nn,
                        coord_frame=FIFF.FIFFV_COORD_DEVICE)
            helmet_rr = mne.surface.transform_surface_to(
                surf, 'head', meg_info['dev_head_t'])['rr']
            p = np.linspace(0, 2 * np.pi, 40)
            x_sphere, y_sphere = rad * np.sin(p), rad * np.cos(p)
            for ai, ax in enumerate(axs):
                others = np.setdiff1d(np.arange(3), [ai])
                ax.plot(helmet_rr[:, others[0]],
                        helmet_rr[:, others[1]],
                        marker='o',
                        linestyle='none',
                        alpha=0.1,
                        markeredgecolor='none',
                        markerfacecolor='b',
                        zorder=-2)
                ax.plot(x_sphere + r0[others[0]],
                        y_sphere + r0[others[1]],
                        color='y',
                        alpha=0.25,
                        zorder=-1)
                ax.quiver(rr[:, others[0]],
                          rr[:, others[1]],
                          amp * nn[:, others[0]],
                          amp * nn[:, others[1]],
                          angles='xy',
                          units='x',
                          color='k',
                          alpha=0.5)
                ax.set_aspect('equal')
                ax.set_xlabel(' - ' + 'xyz'[others[0]] + ' + ')
                ax.set_ylabel(' - ' + 'xyz'[others[1]] + ' + ')
                ax.set_xticks([])
                ax.set_yticks([])
                plt.setp(list(ax.spines.values()), color='none')
            plt.tight_layout()

    if plot_raw or plot_evoked:
        with printer('Low-pass filtering simulated data'):
            events = mne.find_events(raw_mv, 'STI101', verbose=False)
            b, a = signal.butter(4,
                                 40. / (raw.info['sfreq'] / 2.),
                                 'low',
                                 analog=False)
            raw_mv.filter(None,
                          40.,
                          method='iir',
                          iir_params=dict(b=b, a=a),
                          verbose=False,
                          n_jobs=n_jobs)
        if plot_raw:
            with printer('Plotting raw data'):
                raw_mv.plot(clipping='transparent', events=events, show=False)
        if plot_evoked:
            with printer('Plotting evoked data'):
                picks = mne.pick_types(raw_mv.info, meg=True, eeg=True)
                events[:, 2] = 1
                evoked = mne.Epochs(raw_mv, events, {
                    'Simulated': 1
                }, 0, duration, None, picks).average()
                evoked.plot_topomap(np.unique(tmids), show=False)

    print('\nTotal time: %0.1f sec' % (time.time() - t0))
    sys.stdout.flush()
    if any([plot_dipoles, plot_raw, plot_evoked]):
        plt.show(block=True)
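# Hedged sketch of a --dipoles definition file consistent with the parsing in
# run(): one header row, then eight columns per dipole, namely x y z (position in
# mm), nx ny nz (orientation vector whose norm gives the amplitude in nAm), and
# tmin tmax (in seconds).
import numpy as np

dipole_txt = """x y z nx ny nz tmin tmax
 50.0  0.0  40.0   0.0   0.0  20.0   0.1   0.3
-50.0  0.0  40.0  10.0   0.0   0.0   0.5   0.7
"""
with open('dipoles.txt', 'w') as fid:
    fid.write(dipole_txt)
dipoles = np.loadtxt('dipoles.txt', skiprows=1, dtype=float)  # same call as run()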
Example #11
def apply_mft(fwdname,
              datafile,
              evocondition=None,
              meg='mag',
              exclude='bads',
              mftpar=None,
              calccdm=None,
              cdmcut=0.,
              cdmlabels=None,
              subject=None,
              save_stc=True,
              verbose=False):
    """ Apply MFT to specified data set.

    Parameters
    ----------
    fwdname: name of forward solution file
    datafile: name of datafile (ave or raw)
    evocondition: condition in case of evoked input file
    meg: meg-channels to pick ['mag']
    exclude: meg-channels to exclude ['bads']
    mftpar: dictionary with parameters for MFT algorithm
    calccdm : str | None
              where str can be 'all', 'both', 'left', 'right', or 'glob'
    cdmcut : (rel.) cut to use in cdm-calculations [0.]
    cdmlabels: list of labels to analyse
               entries for 'cdmlabels', 'jlglabels', 'jtotlabels'
               in qualdata are returned, containing cdm,
               longitudinal and total current for each label.
    subject : str | None
        The subject name. While not necessary, it is safer to set the
        subject parameter to avoid analysis errors.
    verbose: control variable for verbosity
             False='CRITICAL','WARNING',True='INFO','DEBUG'
             or 'chatty'='verbose' (>INFO,<DEBUG)

    Returns
    -------
    qualmft: dictionary with relerr,rdmerr,mag-arrays and
             cdm-arrays (if requested)
    stcdata: stc with ||cdv|| at fwdmag['source_rr']
             (type corresponding to forward solution)
    """
    twgbl0 = time.time()
    tcgbl0 = time.clock()

    # Use mftparm as a local copy of mftpar to keep the original read-only.
    mftparm = {}
    if mftpar:
        mftparm.update(mftpar)
    mftparm.setdefault('iter', 8)
    mftparm.setdefault('currexp', 1)
    mftparm.setdefault('prbfct', 'uniform')
    mftparm.setdefault('prbcnt')
    mftparm.setdefault('prbhw')
    mftparm.setdefault('regtype', 'PzetaE')
    mftparm.setdefault('zetareg', 1.00)
    mftparm.setdefault('solver', 'lu')
    mftparm.setdefault('svrelcut', 5.e-4)

    if mftparm['solver'] == 'svd':
        use_svd = True
        use_lud = False
        svrelcut = mftparm['svrelcut']
    elif mftparm['solver'] == 'lu' or mftparm['solver'] == 'ludecomp':
        use_lud = True
        use_svd = False
    else:
        raise ValueError(
            ">>>>> mftpar['solver'] must be either 'svd' or 'lu[decomp]'")

    if mftparm['prbfct'].lower() == 'gauss':
        if not mftparm['prbcnt'].all() or not mftparm['prbhw'].all():
            raise ValueError(
                ">>>>> 'prbfct'='Gauss' requires 'prbcnt' and 'prbhw' entries")
    elif mftparm['prbfct'].lower() != 'uniform' and mftparm['prbfct'].lower(
    ) != 'flat':
        raise ValueError(">>>>> unrecognized keyword for 'prbfct'")
    if mftparm['prbcnt'] is None and mftparm['prbhw'] is None:
        prbcnt = np.array([0.0, 0.0, 0.0], ndmin=2)
        prbdhw = np.array([0.0, 0.0, 0.0], ndmin=2)
    else:
        prbcnt = np.reshape(mftparm['prbcnt'],
                            (len(mftparm['prbcnt'].flatten()) / 3, 3))
        prbdhw = np.reshape(mftparm['prbhw'],
                            (len(mftparm['prbhw'].flatten()) / 3, 3))
    if prbcnt.shape != prbdhw.shape:
        raise ValueError(
            ">>>>> mftpar['prbcnt'] and mftpar['prbhw'] must have same size")

    verbosity = 1
    if verbose == False or verbose == 'CRITICAL':
        verbosity = -1
    elif verbose == 'WARNING':
        verbosity = 0
    elif verbose == 'chatty' or verbose == 'verbose':
        verbose = 'INFO'
        verbosity = 2
    elif verbose == 'DEBUG':
        verbosity = 3

    if verbosity >= 0:
        print "meg-channels     = ", meg
        print "exclude-channels = ", exclude
        print "mftpar['iter'    ] = ", mftparm['iter']
        print "mftpar['currexp' ] = ", mftparm['currexp']
        print "mftpar['regtype' ] = ", mftparm['regtype']
        print "mftpar['zetareg' ] = ", mftparm['zetareg']
        print "mftpar['solver'  ] = ", mftparm['solver']
        print "mftpar['svrelcut'] = ", mftparm['svrelcut']
        print "mftpar['prbfct'  ] = ", mftparm['prbfct']
        print "mftpar['prbcnt'  ] = ", mftparm['prbcnt']
        print "mftpar['prbhw'   ] = ", mftparm['prbhw']
        if mftparm['prbcnt'] is not None or mftparm['prbhw'] is not None:
            for icnt in xrange(prbcnt.shape[0]):
                print "  pos(prbcnt[%d])   = " % (icnt + 1), prbcnt[icnt]
                print "  dhw(prbdhw[%d])   = " % (icnt + 1), prbdhw[icnt]
        if calccdm:
            print "calccdm = '%s' with rel. cut = %5.2f" % (calccdm, cdmcut)
    if calccdm and (cdmcut < 0. or cdmcut >= 1.):
        raise ValueError(">>>>> cdmcut must be in [0,1)")

    # Msg will be written by mne.read_forward_solution()
    fwd = mne.read_forward_solution(fwdname, verbose=verbose)
    # Block off fixed_orientation fwd-s for now:
    if fwd['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI:
        raise ValueError(
            ">>>>> apply_mft() cannot handle fixed-orientation fwd-solutions")

    # Select magnetometer channels:
    fwdmag = mne.io.pick.pick_types_forward(fwd,
                                            meg=meg,
                                            ref_meg=False,
                                            eeg=False,
                                            exclude=exclude)
    lfmag = fwdmag['sol']['data']

    n_sens, n_loc = lfmag.shape
    n_srcspace = len([s['vertno'] for s in fwdmag['src']])
    if verbosity >= 2:
        print "Leadfield size : n_sen x n_loc = %d x %d" % (n_sens, n_loc)
        print "Number of source spaces = %d" % n_srcspace

    if cdmlabels is not None:
        if verbosity >= 1:
            print "########## Searching for label(s) in source space(s)..."
        tc0 = time.clock()
        tw0 = time.time()

    numcdmlabels = 0
    labvrtstot = 0
    labvrtsusd = 0
    if cdmlabels is not None:
        invmri_head_t = mne.transforms.invert_transform(
            fwdmag['info']['mri_head_t'])
        mrsrcpnt = np.zeros(fwdmag['source_rr'].shape)
        mrsrcpnt = mne.transforms.apply_trans(invmri_head_t['trans'],
                                              fwdmag['source_rr'])
        offsets = [0]
        for s in fwdmag['src']:
            offsets = np.append(offsets, [offsets[-1] + s['nuse']])
        labinds = []
        ilab = 0
        for label in cdmlabels:
            ilab = ilab + 1

            labvrts = []
            # Find src corresponding to this label (match by position)
            # (Assume surface-labels are in head-cd, vol-labels in MR-cs)
            isrc = 0
            for s in fwdmag['src']:
                isrc += 1

                labvrts = label.get_vertices_used(vertices=s['vertno'])
                numlabvrts = len(labvrts)
                if numlabvrts == 0:
                    continue
                if not np.all(s['inuse'][labvrts]):
                    print "isrc = %d: label='%s' (np.all(s['inuse'][labvrts])=False)" % (
                        isrc, label.name)
                    continue
                #    labindx: indices of used label-vertices in this src-space + offset into 'source_rr'
                # iinlabused: indices of used label-vertices in this label
                labindx = np.searchsorted(s['vertno'],
                                          labvrts) + offsets[isrc - 1]
                iinlabused = np.searchsorted(label.vertices, labvrts)
                if s['type'] == 'surf':
                    if not np.allclose(mrsrcpnt[labindx, :],
                                       label.pos[iinlabused]):
                        continue  # mismatch
                else:
                    if not np.allclose(fwdmag['source_rr'][labindx, :],
                                       label.pos[iinlabused]):
                        continue  # mismatch
                if verbosity >= 1:
                    print "%3d %30s %7s: %5d verts %4d used" % \
                          (ilab, label.name, label.hemi, len(label.vertices), numlabvrts)
                break  # from src-space-loop

            if len(labvrts) > 0:
                labvrtstot += len(label.vertices)
                labvrtsusd += len(labvrts)
                labinds.append(labindx)
                numcdmlabels = len(labinds)
            else:
                warnings.warn(
                    'NO vertex found for label \'%s\' in any source space' %
                    label.name)
        if verbosity >= 1:
            print "--> sums: %5d verts %4d used" % (labvrtstot, labvrtsusd)
            tc1 = time.clock()
            tw1 = time.time()
            print "prep. labels took %.3f" % (
                1000. * (tc1 - tc0)), "ms (%.3f s walltime)" % (tw1 - tw0)

    if datafile.rfind('-ave.fif') > 0 or datafile.rfind('-ave.fif.gz') > 0:
        if verbosity >= 0:
            print "Reading evoked data from %s" % datafile
        if evocondition is None:
            #indatinfo = mne.io.read_info(datafile)
            indathndl = mne.read_evokeds(datafile,
                                         baseline=(None, 0),
                                         verbose=verbose)
            if len(indathndl) > 1:
                raise ValueError(
                    ">>>>> need to specify a condition for this datafile. Aborting-"
                )
            picks = mne.io.pick.pick_types(indathndl[0].info,
                                           meg=meg,
                                           ref_meg=False,
                                           eeg=False,
                                           stim=False,
                                           exclude=exclude)
            data = indathndl[0].data[picks, :]
        else:
            indathndl = mne.read_evokeds(datafile,
                                         condition=evocondition,
                                         baseline=(None, 0),
                                         verbose=verbose)
            #if len(indathndl) > 1:
            #    raise ValueError(">>>>> need to specify a condition for this datafile. Aborting-")
            picks = mne.io.pick.pick_types(indathndl.info,
                                           meg=meg,
                                           ref_meg=False,
                                           eeg=False,
                                           stim=False,
                                           exclude=exclude)
            data = indathndl.data[picks, :]
    elif datafile.rfind('-raw.fif') > 0 or datafile.rfind('-raw.fif.gz') > 0:
        if verbosity >= 0:
            print "Reading raw data from %s" % datafile
        indathndl = mne.io.Raw(datafile, preload=True, verbose=verbose)
        picks = mne.io.pick.pick_types(indathndl.info,
                                       meg=meg,
                                       ref_meg=False,
                                       eeg=False,
                                       stim=False,
                                       exclude=exclude)
        data = indathndl._data[picks, :]
    else:
        raise ValueError(
            ">>>>> datafile is neither 'ave' nor 'raw'. Aborting-")
    if verbosity >= 3:
        print "data.shape = ", data.shape
    if n_sens != data.shape[0]:
        raise ValueError(
            ">>>>> Mismatch in #channels for forward (%d) and data (%d) files. Aborting."
            % (n_sens, data.shape[0]))

    tptotwall = 0.
    tptotcpu = 0.
    nptotcall = 0
    tltotwall = 0.
    tltotcpu = 0.
    nltotcall = 0
    tpcdmwall = 0.
    tpcdmcpu = 0.
    npcdmcall = 0
    if verbosity >= 1:
        print "########## Calculate initial prob-dist:"
    tw0 = time.time()
    tc0 = time.clock()
    if mftparm['prbfct'].lower() == 'gauss':
        wtmp = np.zeros(n_loc / 3)
        for icnt in xrange(prbcnt.shape[0]):
            testdiff = fwdmag['source_rr'] - prbcnt[icnt, :]
            testdiff = testdiff / prbdhw[icnt, :]
            testdiff = testdiff * testdiff
            testsq = np.sum(testdiff, 1)
            wtmp += np.exp(-testsq)
        wdist0 = wtmp / (np.sum(wtmp) * np.sqrt(3.))
    elif mftparm['prbfct'].lower() in ('flat', 'uniform'):
        if verbosity >= 2:
            print "Setting initial w=const !"
        wdist0 = np.ones(n_loc / 3) / (float(n_loc) / np.sqrt(3.))
    else:
        raise ValueError(
            ">>>>> mftpar['prbfct'] must be 'Gauss' or 'uniform'/'flat'")
    wdist3 = np.repeat(wdist0, 3)
    if verbosity >= 3:
        wvecnorm = np.sum(
            np.sqrt(
                np.sum(np.reshape(wdist3, (wdist3.shape[0] / 3, 3))**2,
                       axis=1)))
        print "sum(||wvec(i)||) = ", wvecnorm
    tc1 = time.clock()
    tw1 = time.time()
    if verbosity >= 1:
        print "calc(wdist0) took %.3f" % (
            1000. * (tc1 - tc0)), "ms (%.3f s walltime)" % (tw1 - tw0)

    if verbosity >= 1:
        print "########## Calculate P-matrix, incl. weights:"
    tw0 = time.time()
    tc0 = time.clock()
    lfw = lfmag * np.repeat(np.sqrt(wdist0), 3)
    pmat0 = np.einsum('ik,jk->ij', lfw, lfw)
    # Avoiding sqrt is expensive!
    # pmat0 = np.einsum('ik, k, jk -> ij', lfmag, wdist3, lfmag)
    tc1 = time.clock()
    tw1 = time.time()
    tptotwall += (tw1 - tw0)
    tptotcpu += (tc1 - tc0)
    nptotcall += 1
    if verbosity >= 1:
        print "calc(lf*w*lf.T) took ", 1000. * (
            tc1 - tc0), "ms (%.3f s walltime)" % (tw1 - tw0)

    # Normalize P:
    pmax = np.amax([np.abs(np.amax(pmat0)), np.abs(np.amin(pmat0))])
    if verbosity >= 3:
        print "pmax(init) = ", pmax
    pscalefct = 1.
    while pmax > 1.0:
        pmax /= 2.
        pscalefct /= 2.
    while pmax < 0.5:
        pmax *= 2.
        pscalefct *= 2.
    #print ">>>>> Keeping scale factor eq 1"
    #pscalefct = 1.
    pmat0 = pmat0 * pscalefct
    if verbosity >= 3:
        print "pmax(fin.) = ", np.amax(
            [np.abs(np.amax(pmat0)),
             np.abs(np.amin(pmat0))])

    # Regularize P:
    if mftparm['regtype'] == 'PzetaE':
        zetatrp = mftparm['zetareg'] * np.trace(pmat0) / float(pmat0.shape[0])
        if verbosity >= 3:
            print "Use PzetaE-regularization with zeta*tr(P)/ncol(P) = %12.5e" % zetatrp
        ptilde0 = pmat0 + zetatrp * np.identity(pmat0.shape[0])
    elif mftparm['regtype'] == 'classic' or mftparm['regtype'] == 'PPzetaP':
        zetatrp = mftparm['zetareg'] * np.trace(pmat0) / float(pmat0.shape[0])
        if verbosity >= 3:
            print "Use PPzetaP-regularization with zeta*tr(P)/ncol(P) = %12.5e" % zetatrp
        ptilde0 = np.dot(pmat0, pmat0) + zetatrp * pmat0
    else:
        raise ValueError(
            ">>>>> mftpar['regtype'] must be 'PzetaE' or 'classic''")

    # decompose:
    if use_lud is True:
        LU0, P0 = scipy.linalg.lu_factor(ptilde0)
        #rhstmp = np.zeros([LU0.shape[1]])
        #xtmp = np.empty([LU0.shape[1]])
        #xtmp = scipy.linalg.lu_solve((LU0,P0),rhstmp)
        if verbosity >= 3:
            # Calculate condition number:
            #(sign, lndetbf) = np.linalg.slogdet(ptilde0)
            lndettr = np.sum(np.log(np.abs(np.diag(LU0))))
            #print "lndet(ptilde0) = %8.3f =?= %8.3f = sum(log(|diag(LU0)|))" % (lndetbf,lndettr)
            # log(prod(a_i, i=1,n)) for a_i = sqrt(sum(ptilde0_ij^2, j=1,n))
            denom = np.sum(np.log(np.sqrt(np.sum(ptilde0 * ptilde0, axis=0))))
            lncondno = lndettr - denom
            print "ln(condno) = %8.3f, K_H = 10^(%8.3f) = %8.3f" % (
                lncondno, lncondno / np.log(10.), np.exp(lncondno))
            print "(K_H < 0.01 : bad, K_H > 0.1 : good)"

    if use_svd is True:
        U, s, V = np.linalg.svd(ptilde0, full_matrices=True)
        if verbosity >= 2:
            print ">>> SV range %e ... %e" % (np.amax(s), np.amin(s))
        dtmp = s.max() * svrelcut
        s *= (abs(s) >= dtmp)
        sinv = [
            1. / s[k] if s[k] != 0. else 0. for k in xrange(ptilde0.shape[0])
        ]
        if verbosity >= 2:
            print ">>> With rel-cutoff=%e   %d out of %d SVs remain" % \
                  (svrelcut,np.array(np.nonzero(sinv)).shape[1],len(sinv))
        if verbosity >= 3:
            stat = np.allclose(ptilde0, np.dot(U, np.dot(np.diag(s), V)))
            print ">>> Testing svd-result: %s" % stat
            if not stat:
                print "    (Maybe due to SV-cutoff?)"
            print ">>> Setting ptildeinv=(U diag(sinv) V).tr"
        ptilde0inv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if verbosity >= 3:
            stat = np.allclose(np.identity(ptilde0.shape[0]),
                               np.dot(ptilde0inv, ptilde0))
            if stat:
                print ">>> Testing ptilde0inv-result (shld be unit-matrix): ok"
            else:
                print ">>> Testing ptilde0inv-result (shld be unit-matrix): failed"
                print np.transpose(np.dot(ptilde0inv, ptilde0))
                print ">>>"

    if verbosity >= 1:
        print "########## Create stc data and qual data arrays:"
    qualdata = {
        'relerr': np.zeros(data.shape[1]),
        'rdmerr': np.zeros(data.shape[1]),
        'mag': np.zeros(data.shape[1])
    }
    if calccdm is not None:
        if verbosity >= 0 and \
           n_srcspace ==1 and (calccdm == 'left' or calccdm == 'right'):
            print ">>>Warning>> cdm-results may differ from what you expect."
        ids = data.shape[1]
        if calccdm == 'all':
            (qualdata['cdmall'], qualdata['jlgall']) = (np.zeros(ids),
                                                        np.zeros(ids))
            (qualdata['cdmleft'], qualdata['jlgleft']) = (np.zeros(ids),
                                                          np.zeros(ids))
            (qualdata['cdmright'], qualdata['jlgright']) = (np.zeros(ids),
                                                            np.zeros(ids))
        elif calccdm == 'both':
            (qualdata['cdmleft'], qualdata['jlgleft']) = (np.zeros(ids),
                                                          np.zeros(ids))
            (qualdata['cdmright'], qualdata['jlgright']) = (np.zeros(ids),
                                                            np.zeros(ids))
        elif calccdm == 'left':
            (qualdata['cdmleft'], qualdata['jlgleft']) = (np.zeros(ids),
                                                          np.zeros(ids))
        elif calccdm == 'right':
            (qualdata['cdmright'], qualdata['jlgright']) = (np.zeros(ids),
                                                            np.zeros(ids))
        elif calccdm == 'glob':
            (qualdata['cdmall'], qualdata['jlgall']) = (np.zeros(ids),
                                                        np.zeros(ids))
        if qualdata.has_key('cdmleft'):
            fwdlhinds = np.where(fwdmag['source_rr'][:, 0] < 0.)[0]
        if qualdata.has_key('cdmright'):
            fwdrhinds = np.where(fwdmag['source_rr'][:, 0] > 0.)[0]
    if cdmlabels is not None and numcdmlabels > 0:
        qualdata['cdmlabels'] = np.zeros((numcdmlabels, data.shape[1]))
        qualdata['jlglabels'] = np.zeros((numcdmlabels, data.shape[1]))
        qualdata['jtotlabels'] = np.zeros((numcdmlabels, data.shape[1]))

    stcdata = np.zeros([n_loc / 3, data.shape[1]])

    if verbosity >= 2:
        print "Reading %d slices of data to calc. cdv:" % data.shape[1]
        if data.shape[1] > 1000:
            print " "
    for islice in xrange(data.shape[1]):
        wdist = np.copy(wdist0)
        wdist3 = np.repeat(wdist, 3)
        pmat = np.copy(pmat0)
        ptilde = np.copy(ptilde0)
        if use_svd is True:
            ptildeinv = np.copy(ptilde0inv)
        if use_lud is True:
            LU = np.copy(LU0)
            P = np.copy(P0)

        slice = pscalefct * data[:, islice]
        if mftparm['regtype'] == 'PzetaE':
            mtilde = np.copy(slice)
        else:
            mtilde = np.dot(pmat, slice)

        acoeff = np.empty([ptilde.shape[0]])
        if use_svd is True:
            for irow in xrange(ptilde.shape[0]):
                acoeff[irow] = np.dot(ptildeinv[irow, :], mtilde)
        if use_lud is True:
            acoeff = scipy.linalg.lu_solve((LU, P), mtilde)

        cdv = np.zeros(n_loc)
        cdvnorms = np.zeros(n_loc / 3)
        for krow in xrange(lfmag.shape[0]):
            lfwtmp = lfmag[krow, :] * wdist3
            cdv += acoeff[krow] * lfwtmp

        tlw0 = time.time()
        tlc0 = time.clock()
        for mftiter in xrange(mftparm['iter']):
            # MFT iteration loop:

            cdvecs = np.reshape(cdv, (cdv.shape[0] / 3, 3))
            cdvnorms = np.sqrt(np.sum(cdvecs**2, axis=1))

            wdist = np.power(cdvnorms, mftparm['currexp']) * wdist0
            wdistsum = np.sum(wdist)
            wdist = wdist / wdistsum
            wdist3 = np.repeat(wdist, 3)

            # Calculate new P-matrix, incl. weights:
            tw0 = time.time()
            tc0 = time.clock()
            lfw = lfmag * np.repeat(np.sqrt(pscalefct * wdist), 3)
            pmat = np.einsum('ik,jk->ij', lfw, lfw)

            tc1 = time.clock()
            tw1 = time.time()
            tptotwall += (tw1 - tw0)
            tptotcpu += (tc1 - tc0)
            nptotcall += 1

            # Regularize P:
            if mftparm['regtype'] == 'PzetaE':
                ptilde = pmat + zetatrp * np.identity(pmat.shape[0])
            else:
                ptilde = np.dot(pmat, pmat) + zetatrp * pmat

            # decompose:
            if use_svd is True:
                U, s, V = np.linalg.svd(ptilde, full_matrices=True)
                dtmp = s.max() * svrelcut
                s *= (abs(s) >= dtmp)
                sinv = [
                    1. / s[k] if s[k] != 0. else 0.
                    for k in xrange(ptilde.shape[0])
                ]
                ptildeinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
                for irow in xrange(ptilde.shape[0]):
                    acoeff[irow] = np.dot(ptildeinv[irow, :], mtilde)
            if use_lud is True:
                LU, P = scipy.linalg.lu_factor(ptilde)
                acoeff = scipy.linalg.lu_solve((LU, P), mtilde)

            cdv = np.einsum('ji,i,j->i', lfmag, wdist3, acoeff)

        tc1 = time.clock()
        tw1 = time.time()
        tltotwall += (tw1 - tlw0)
        tltotcpu += (tc1 - tlc0)
        nltotcall += 1
        cdvecs = np.reshape(cdv, (cdv.shape[0] / 3, 3))
        cdvnorms = np.sqrt(np.sum(cdvecs**2, axis=1))
        #(relerr,rdmerr,mag) = compare_est_exp(ptilde,acoeff,mtilde)
        (relerr, rdmerr, mag) = compare_est_exp(pmat, acoeff, slice)
        qualdata['relerr'][islice] = relerr
        qualdata['rdmerr'][islice] = rdmerr
        qualdata['mag'][islice] = mag

        tc0 = time.clock()
        tw0 = time.time()
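        # Per-slice quality measures over selected voxel subsets (all voxels,
        # left/right hemisphere, and optional labels).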
        if 'cdmall' in qualdata:
            (qualdata['cdmall'][islice],
             qualdata['jlgall'][islice]) = scan_cdm_w_cut(cdv, cdmcut)
        if 'cdmleft' in qualdata:
            (qualdata['cdmleft'][islice],
             qualdata['jlgleft'][islice]) = \
                scan_cdm_w_cut(cdvecs[fwdlhinds, :], cdmcut)
        if 'cdmright' in qualdata:
            (qualdata['cdmright'][islice],
             qualdata['jlgright'][islice]) = \
                scan_cdm_w_cut(cdvecs[fwdrhinds, :], cdmcut)
        if 'cdmlabels' in qualdata:
            for ilab in xrange(numcdmlabels):
                (qualdata['cdmlabels'][ilab, islice],
                 qualdata['jlglabels'][ilab, islice]) = \
                    scan_cdm_w_cut(cdvecs[labinds[ilab], :], cdmcut)
                qualdata['jtotlabels'][ilab, islice] = \
                    calc_jtotal_w_cut(cdvecs[labinds[ilab], :], cdmcut)
        tc1 = time.clock()
        tw1 = time.time()
        tpcdmwall += (tw1 - tw0)
        tpcdmcpu += (tc1 - tc0)
        npcdmcall += 1

        # Write final cdv to file:
        for iloc in xrange(n_loc / 3):
            stcdata[iloc, islice] = cdvnorms[iloc]
        del wdist
        if verbosity >= 2 and islice > 0 and islice % 1000 == 0:
            print "\r%6d out of %6d slices done." % (islice, data.shape[1])
    if verbosity >= 2 and data.shape[1] > 1000:
        print "Done."

    vertices = [s['vertno'] for s in fwdmag['src']]
    tstep = 1. / indathndl.info['sfreq']
    tmin = indathndl.times[0]
    if len(vertices) == 1:
        stc_mft = mne.VolSourceEstimate(stcdata,
                                        vertices=fwdmag['src'][0]['vertno'],
                                        tmin=tmin,
                                        tstep=tstep,
                                        subject=subject)
    elif len(vertices) == 2:
        stc_mft = mne.SourceEstimate(stcdata,
                                     vertices=vertices,
                                     tmin=tmin,
                                     tstep=tstep,
                                     subject=subject)
    else:
        vertices = np.concatenate(([s['vertno'] for s in fwdmag['src']]))
        stc_mft = mne.VolSourceEstimate(stcdata,
                                        vertices=vertices,
                                        tmin=tmin,
                                        tstep=tstep,
                                        subject=subject)

    stcdatamft = stc_mft.data
    print "##### Results:"
    for islice in xrange(data.shape[1]):
        print "slice=%4d: relerr=%9.3e rdmerr=%9.3e mag=%5.3f cdvmax=%9.2e" % \
              (islice, qualdata['relerr'][islice], qualdata['rdmerr'][islice], qualdata['mag'][islice],\
               np.amax(stcdatamft[:, islice]))
    stat = np.allclose(stcdata, stcdatamft, atol=0., rtol=1e-07)
    if stat:
        print "stcdata from mft-SR and old calc agree."
    else:
        print "stcdata from mft-SR and old calc DIFFER."

    if save_stc:
        # Save the resulting source estimate (surface or volume).
        print "##### Trying to save stc:"
        stcmft_fname = os.path.join(
            os.path.dirname(datafile),
            os.path.basename(datafile).split('-')[0]) + "mft"
        print "stcmft basefilename: %s" % stcmft_fname
        stc_mft.save(stcmft_fname, verbose=True)
        print "##### done."

    twgbl1 = time.time()
    tcgbl1 = time.clock()
    if verbosity >= 1:
        print "calc(lf*w*lf.T) took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            tptotcpu, tptotwall)
        print "calc(lf*w*lf.T) took per call %9.2fms CPU-time (%9.2fms walltime)" % \
                               (1000.*tptotcpu/float(nptotcall),1000.*tptotwall/float(nptotcall))
        print "scan_cdm calls  took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            tpcdmcpu, tpcdmwall)
        print "scan_cdm calls  took per call %9.2fms CPU-time (%9.2fms walltime)" % \
                               (1000.*tpcdmcpu/float(npcdmcall),1000.*tpcdmwall/float(npcdmcall))
        print "iteration-loops took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            tltotcpu, tltotwall)
        print "iteration-loops took per call %9.2fms CPU-time (%9.2fms walltime)" % \
                               (1000.*tltotcpu/float(nltotcall),1000.*tltotwall/float(nltotcall))
        print "Total mft-call  took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            (tcgbl1 - tcgbl0), (twgbl1 - twgbl0))
    return (fwdmag, qualdata, stc_mft)
# normalize to unit length
tangential = (tangential.T * (1. / np.linalg.norm(tangential, axis=1))).T

nn = tangential

###############################################################################
# Simulate a single signal dipole source as signal
###############################################################################

signal_vertex = src[0]['vertno'][config.vertex]
data = np.asarray([generate_signal(times, freq=config.signal_freq)])
vertices = np.array([signal_vertex])
stc_signal = mne.VolSourceEstimate(data=data,
                                   vertices=vertices,
                                   tmin=0,
                                   tstep=1 / info['sfreq'],
                                   subject='sample')

stc_signal.save(vfname.stc_signal(noise=config.noise, vertex=config.vertex))

###############################################################################
# Create discrete source space based on voxels in volume source space
###############################################################################

if not op.exists(vfname.fwd_discrete):

    pos = {'rr': rr, 'nn': nn}

    # make discrete source space
    src_disc = mne.setup_volume_source_space(subject='sample',
Example No. 13
def rsa_source_level(stcs,
                     dsm_model,
                     src,
                     spatial_radius=0.04,
                     temporal_radius=0.1,
                     stc_dsm_metric='correlation',
                     stc_dsm_params=dict(),
                     rsa_metric='spearman',
                     y=None,
                     n_folds=1,
                     sel_vertices=None,
                     tmin=None,
                     tmax=None,
                     n_jobs=1,
                     verbose=False):
    """Perform RSA in a searchlight pattern across the source space.

    The output is a source estimate where the "signal" at each source point is
    the RSA, computed for a patch surrounding the source point.

    Parameters
    ----------
    stcs : list of mne.SourceEstimate | list of mne.VolSourceEstimate
        For each item, a source estimate for the brain activity.
    dsm_model : ndarray, shape (n, n) | (n * (n - 1) // 2,) | list of ndarray
        The model DSM, see :func:`compute_dsm`. For efficiency, you can give it
        in condensed form, meaning only the upper triangle of the matrix as a
        vector. See :func:`scipy.spatial.distance.squareform`. To perform RSA
        against multiple models at the same time, supply a list of model DSMs.

        Use :func:`compute_dsm` to compute DSMs.
    src : instance of mne.SourceSpaces
        The source space used by the source estimates specified in the `stcs`
        parameter.
    spatial_radius : float | None
        The spatial radius of the searchlight patch in meters. All source
        points within this radius will belong to the searchlight patch. Set to
        None to only perform the searchlight over time, flattening across
        source points. Defaults to 0.04.
    temporal_radius : float | None
        The temporal radius of the searchlight patch in seconds. Set to None to
        only perform the searchlight over source points, flattening across
        time. Defaults to 0.1.
    stc_dsm_metric : str
        The metric to use to compute the DSM for the source estimates. This can
        be any metric supported by :func:`scipy.spatial.distance.pdist`. See
        also the ``stc_dsm_params`` parameter to specify additional parameters
        for the distance function. Defaults to 'correlation'.
    stc_dsm_params : dict
        Extra arguments for the distance metric used to compute the DSMs.
        Refer to :mod:`scipy.spatial.distance` for a list of all other metrics
        and their arguments. Defaults to an empty dictionary.
    rsa_metric : str
        The RSA metric to use to compare the DSMs. Valid options are:

        * 'spearman' for Spearman's correlation (the default)
        * 'pearson' for Pearson's correlation
        * 'kendall-tau-a' for Kendall's Tau (alpha variant)
        * 'partial' for partial Pearson correlations
        * 'partial-spearman' for partial Spearman correlations
        * 'regression' for linear regression weights

        Defaults to 'spearman'.
    y : ndarray of int, shape (n_items,) | None
        For each source estimate, a number indicating the item to which it
        belongs. When ``None``, each source estimate is assumed to belong to a
        different item. Defaults to ``None``.
    n_folds : int | None
        Number of folds to use when using cross-validation to compute the
        source-estimate DSMs. Specify ``None`` to use the maximum number of
        folds possible, given the data.
        Defaults to 1 (no cross-validation).
    sel_vertices : list of int | None
        When set, searchlight patches will only be generated for the subset of
        vertices/voxels with the given indices. Defaults to ``None``, in which
        case patches for all vertices/voxels are generated.
    tmin : float | None
        When set, searchlight patches will only be generated from subsequent
        time points starting from this time point. This value is given in
        seconds. Defaults to ``None``, in which case patches are generated
        starting from the first time point.
    tmax : float | None
        When set, searchlight patches will only be generated up to and
        including this time point. This value is given in seconds. Defaults to
        ``None``, in which case patches are generated up to and including the
        last time point.
    n_jobs : int
        The number of processes (=number of CPU cores) to use. Specify -1 to
        use all available cores. Defaults to 1.
    verbose : bool
        Whether to display a progress bar. In order for this to work, you need
        the tqdm python module installed. Defaults to False.

    Returns
    -------
    stc : SourceEstimate | VolSourceEstimate | list of SourceEstimate | list of VolSourceEstimate
        The correlation values for each searchlight patch. When spatial_radius
        is set to None, there will only be one vertex. When temporal_radius is
        set to None, there will only be one time point. When multiple models
        have been supplied, a list will be returned containing the RSA results
        for each model.

    See Also
    --------
    compute_dsm
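
    Examples
    --------
    A minimal usage sketch; ``stcs``, ``model_dsm`` and ``src`` are assumed to
    already exist, with ``model_dsm`` a model DSM as returned by
    :func:`compute_dsm`::

        rsa_stc = rsa_source_level(stcs, model_dsm, src,
                                   spatial_radius=0.04,
                                   temporal_radius=0.1,
                                   n_jobs=4, verbose=True)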
    """  # noqa E501
    # Accept a single model DSM by wrapping it in a list
    one_model = type(dsm_model) is np.ndarray
    if one_model:
        dsm_model = [dsm_model]

    # Check for compatibility of the source estimates and the model features
    for dsm in dsm_model:
        n_items = _n_items_from_dsm(dsm)
        if len(stcs) != n_items and y is None:
            raise ValueError(
                'The number of source estimates (%d) should be equal to the '
                'number of items in `dsm_model` (%d). Alternatively, use '
                'the `y` parameter to assign source estimates to items.' %
                (len(stcs), n_items))
        if y is not None and len(np.unique(y)) != n_items:
            raise ValueError(
                'The number of items in `dsm_model` (%d) does not match '
                'the number of items encoded in the `y` matrix (%d).' %
                (n_items, len(np.unique(y))))

    _check_compatible(stcs, src)
    dist = _get_distance_matrix(src, dist_lim=spatial_radius, n_jobs=n_jobs)

    if temporal_radius is not None:
        # Convert the temporal radius to samples
        temporal_radius = int(temporal_radius // stcs[0].tstep)

        if temporal_radius < 1:
            raise ValueError('Temporal radius is less than one sample.')

    sel_samples = _tmin_tmax_to_indices(stcs[0].times, tmin, tmax)

    # Perform the RSA
    X = np.array([stc.data for stc in stcs])
    patches = searchlight(X.shape,
                          dist=dist,
                          spatial_radius=spatial_radius,
                          temporal_radius=temporal_radius,
                          sel_series=sel_vertices,
                          sel_samples=sel_samples)
    data = rsa_array(X,
                     dsm_model,
                     patches,
                     data_dsm_metric=stc_dsm_metric,
                     data_dsm_params=stc_dsm_params,
                     rsa_metric=rsa_metric,
                     y=y,
                     n_folds=n_folds,
                     n_jobs=n_jobs,
                     verbose=verbose)

    # Pack the result in a SourceEstimate object
    if spatial_radius is not None:
        vertices = stcs[0].vertices
    else:
        if src.kind == 'volume':
            vertices = [np.array([1])]
        else:
            vertices = [np.array([1]), np.array([])]
    tmin = _construct_tmin(stcs[0].times, sel_samples, temporal_radius)
    tstep = stcs[0].tstep

    if one_model:
        if src.kind == 'volume':
            return mne.VolSourceEstimate(data,
                                         vertices,
                                         tmin,
                                         tstep,
                                         subject=stcs[0].subject)
        else:
            return mne.SourceEstimate(data,
                                      vertices,
                                      tmin,
                                      tstep,
                                      subject=stcs[0].subject)
    else:
        if src.kind == 'volume':
            return [
                mne.VolSourceEstimate(data[:, :, i],
                                      vertices,
                                      tmin,
                                      tstep,
                                      subject=stcs[0].subject)
                for i in range(data.shape[-1])
            ]
        else:
            return [
                mne.SourceEstimate(data[:, :, i],
                                   vertices,
                                   tmin,
                                   tstep,
                                   subject=stcs[0].subject)
                for i in range(data.shape[-1])
            ]
Example No. 14
def apply_mft(fwdname,
              datafile,
              evocondition=None,
              meg='mag',
              exclude='bads',
              mftpar=None,
              subject=None,
              save_stc=True,
              verbose=False):
    """ Apply MFT to specified data set.

    Parameters
    ----------
    fwdname: name of forward solution file
    datafile: name of datafile (ave or raw)
    evocondition: condition in case of evoked input file
    meg: meg-channels to pick ['mag']
    exclude: meg-channels to exclude ['bads']
    mftpar: dictionary with parameters for MFT algorithm
    subject : str | None
        The subject name. While not necessary, it is safer to set the
        subject parameter to avoid analysis errors.
    verbose: control variable for verbosity
             False='CRITICAL','WARNING',True='INFO','DEBUG'
             or 'chatty'='verbose' (>INFO,<DEBUG)

    Returns
    -------
    qualmft: dictionary with relerr,rdmerr,mag-arrays
    stcdata: stc with ||cdv|| at fwdmag['source_rr']
             (type corresponding to forward solution)
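
    Examples
    --------
    A minimal call sketch (file names and condition are hypothetical)::

        fwdmag, qualmft, stc_mft = apply_mft('sample-meg-fwd.fif',
                                             'sample-ave.fif',
                                             evocondition='Left Auditory',
                                             subject='sample',
                                             save_stc=False, verbose=True)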
    """
    twgbl0 = time.time()
    tcgbl0 = time.clock()

    if mftpar is None:
        mftpar = {
            'prbfct': 'uniform',
            'prbcnt': None,
            'prbhw': None,
            'iter': 8,
            'currexp': 1,
            'regtype': 'PzetaE',
            'zetareg': 1.00,
            'solver': 'lu',
            'svrelcut': 5.e-4
        }

    if mftpar['solver'] == 'svd':
        use_svd = True
        use_lud = False
        svrelcut = mftpar['svrelcut']
    elif mftpar['solver'] == 'lu' or mftpar['solver'] == 'ludecomp':
        use_lud = True
        use_svd = False
    else:
        raise ValueError(
            ">>>>> mftpar['solver'] must be either 'svd' or 'lu[decomp]'")

    if mftpar['prbcnt'] is None and mftpar['prbhw'] is None:
        prbcnt = np.array([0.0, 0.0, 0.0], ndmin=2)
        prbdhw = np.array([0.0, 0.0, 0.0], ndmin=2)
    else:
        prbcnt = np.reshape(mftpar['prbcnt'],
                            (len(mftpar['prbcnt'].flatten()) / 3, 3))
        prbdhw = np.reshape(mftpar['prbhw'],
                            (len(mftpar['prbhw'].flatten()) / 3, 3))
    if prbcnt.shape != prbdhw.shape:
        raise ValueError(
            ">>>>> mftpar['prbcnt'] and mftpar['prbhw'] must have same size")

    verbosity = 1
    if verbose == False or verbose == 'CRITICAL':
        verbosity = -1
    elif verbose == 'WARNING':
        verbosity = 0
    elif verbose == 'chatty' or verbose == 'verbose':
        verbose = 'INFO'
        verbosity = 2
    elif verbose == 'DEBUG':
        verbosity = 3

    if verbosity >= 0:
        print "meg-channels     = ", meg
        print "exclude-channels = ", exclude
        print "mftpar['iter'    ] = ", mftpar['iter']
        print "mftpar['regtype' ] = ", mftpar['regtype']
        print "mftpar['zetareg' ] = ", mftpar['zetareg']
        print "mftpar['solver'  ] = ", mftpar['solver']
        print "mftpar['svrelcut'] = ", mftpar['svrelcut']
        print "mftpar['prbfct'  ] = ", mftpar['prbfct']
        print "mftpar['prbcnt'  ] = ", mftpar['prbcnt']
        print "mftpar['prbhw'   ] = ", mftpar['prbhw']
        if mftpar['prbcnt'] is not None or mftpar['prbhw'] is not None:
            for icnt in xrange(prbcnt.shape[0]):
                print "  pos(prbcnt[%d])   = " % (icnt + 1), prbcnt[icnt]
                print "  dhw(prbdhw[%d])   = " % (icnt + 1), prbdhw[icnt]

    # Msg will be written by mne.read_forward_solution()
    fwd = mne.read_forward_solution(fwdname, verbose=verbose)
    # Block off fixed_orientation fwd-s for now:
    if fwd['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI:
        raise ValueError(
            ">>>>> apply_mft() cannot handle fixed-orientation fwd-solutions")

    # Select magnetometer channels:
    fwdmag = mne.io.pick.pick_types_forward(fwd,
                                            meg=meg,
                                            ref_meg=False,
                                            eeg=False,
                                            exclude=exclude)
    lfmag = fwdmag['sol']['data']

    n_sens, n_loc = lfmag.shape
    if verbosity >= 2:
        print "Leadfield size : n_sen x n_loc = %d x %d" % (n_sens, n_loc)

    if datafile.rfind('-ave.fif') > 0 or datafile.rfind('-ave.fif.gz') > 0:
        if verbosity >= 0:
            print "Reading evoked data from %s" % datafile
        if evocondition is None:
            #indatinfo = mne.io.read_info(datafile)
            indathndl = mne.read_evokeds(datafile,
                                         baseline=(None, 0),
                                         verbose=verbose)
            if len(indathndl) > 1:
                raise ValueError(
                    ">>>>> need to specify a condition for this datafile. Aborting-"
                )
            picks = mne.io.pick.pick_types(indathndl[0].info,
                                           meg=meg,
                                           ref_meg=False,
                                           eeg=False,
                                           stim=False,
                                           exclude=exclude)
            data = indathndl[0].data[picks, :]
        else:
            indathndl = mne.read_evokeds(datafile,
                                         condition=evocondition,
                                         baseline=(None, 0),
                                         verbose=verbose)
            #if len(indathndl) > 1:
            #    raise ValueError(">>>>> need to specify a condition for this datafile. Aborting-")
            picks = mne.io.pick.pick_types(indathndl.info,
                                           meg=meg,
                                           ref_meg=False,
                                           eeg=False,
                                           stim=False,
                                           exclude=exclude)
            data = indathndl.data[picks, :]
    elif datafile.rfind('-raw.fif') > 0 or datafile.rfind('-raw.fif.gz') > 0:
        if verbosity >= 0:
            print "Reading raw data from %s" % datafile
        indathndl = mne.io.Raw(datafile, preload=True, verbose=verbose)
        picks = mne.io.pick.pick_types(indathndl.info,
                                       meg=meg,
                                       ref_meg=False,
                                       eeg=False,
                                       stim=False,
                                       exclude=exclude)
        data = indathndl._data[picks, :]
    else:
        raise ValueError(
            ">>>>> datafile is neither 'ave' nor 'raw'. Aborting-")
    if verbosity >= 3:
        print "data.shape = ", data.shape
    if n_sens != data.shape[0]:
        raise ValueError(
            ">>>>> Mismatch in #channels for forward (%d) and data (%d) files. Aborting."
            % (n_sens, data.shape[0]))

    tptotwall = 0.
    tptotcpu = 0.
    nptotcall = 0
    tltotwall = 0.
    tltotcpu = 0.
    nltotcall = 0
    tpcdmwall = 0.
    tpcdmcpu = 0.
    npcdmcall = 0
    if verbosity >= 1:
        print "########## Calculate initial prob-dist:"
    tw0 = time.time()
    tc0 = time.clock()
    if mftpar['prbfct'] == 'Gauss':
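        # Gaussian prior: w0_i is proportional to the sum over the requested
        # centers c of exp(-||(r_i - c) / h_c||**2); normalized afterwards.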
        wtmp = np.zeros(n_loc / 3)
        for icnt in xrange(prbcnt.shape[0]):
            testdiff = fwdmag['source_rr'] - prbcnt[icnt, :]
            testdiff = testdiff / prbdhw[icnt, :]
            testdiff = testdiff * testdiff
            testsq = np.sum(testdiff, 1)
            wtmp += np.exp(-testsq)
        wdist0 = wtmp / (np.sum(wtmp) * np.sqrt(3.))
    elif mftpar['prbfct'] == 'flat' or mftpar['prbfct'] == 'uniform':
        if verbosity >= 2:
            print "Setting w=const !"
        wdist0 = np.ones(n_loc / 3) / (float(n_loc) / np.sqrt(3.))
    else:
        raise ValueError(
            ">>>>> mftpar['prbfct'] must be 'Gauss' or 'uniform'/'flat'")
    wdist3 = np.repeat(wdist0, 3)
    if verbosity >= 3:
        wvecnorm = np.sum(
            np.sqrt(
                np.sum(np.reshape(wdist3, (wdist3.shape[0] / 3, 3))**2,
                       axis=1)))
        print "sum(||wvec(i)||) = ", wvecnorm
    tc1 = time.clock()
    tw1 = time.time()
    if verbosity >= 1:
        print "calc(wdist0) took %.3f" % (
            1000. * (tc1 - tc0)), "ms (%.3f s walltime)" % (tw1 - tw0)

    if verbosity >= 1:
        print "########## Calculate P-matrix, incl. weights:"
    tw0 = time.time()
    tc0 = time.clock()
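    # P-matrix P = L * diag(w) * L.T for lead field L and weights w,
    # evaluated as (L * sqrt(w)) (L * sqrt(w)).T via einsum.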
    wdist3rt = np.repeat(np.sqrt(wdist0), 3)
    lfw = lfmag * wdist3rt
    pmat0 = np.einsum('ik,jk->ij', lfw, lfw)
    tc1 = time.clock()
    tw1 = time.time()
    tptotwall += (tw1 - tw0)
    tptotcpu += (tc1 - tc0)
    nptotcall += 1
    if verbosity >= 1:
        print "calc(lf*w*lf.T) took ", 1000. * (
            tc1 - tc0), "ms (%.3f s walltime)" % (tw1 - tw0)

    # Normalize P:
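    # Rescale P by powers of two so that max|P_ij| lies in [0.5, 1]; the same
    # scale factor is applied to each data slice below.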
    pmax = np.amax([np.abs(np.amax(pmat0)), np.abs(np.amin(pmat0))])
    if verbosity >= 3:
        print "pmax(init) = ", pmax
    pscalefct = 1.
    while pmax > 1.0:
        pmax /= 2.
        pscalefct /= 2.
    while pmax < 0.5:
        pmax *= 2.
        pscalefct *= 2.
    #print ">>>>> Keeping scale factor eq 1"
    #pscalefct = 1.
    pmat0 = pmat0 * pscalefct
    if verbosity >= 3:
        print "pmax(fin.) = ", np.amax(
            [np.abs(np.amax(pmat0)),
             np.abs(np.amin(pmat0))])

    # Regularize P:
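    # 'PzetaE':            ptilde = P + zeta*tr(P)/n * I
    # 'classic'/'PPzetaP': ptilde = P.P + zeta*tr(P)/n * P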
    if mftpar['regtype'] == 'PzetaE':
        zetatrp = mftpar['zetareg'] * np.trace(pmat0) / float(pmat0.shape[0])
        if verbosity >= 3:
            print "Use PzetaE-regularization with zeta*tr(P)/ncol(P) = %12.5e" % zetatrp
        ptilde0 = pmat0 + zetatrp * np.identity(pmat0.shape[0])
    elif mftpar['regtype'] == 'classic' or mftpar['regtype'] == 'PPzetaP':
        zetatrp = mftpar['zetareg'] * np.trace(pmat0) / float(pmat0.shape[0])
        if verbosity >= 3:
            print "Use PPzetaP-regularization with zeta*tr(P)/ncol(P) = %12.5e" % zetatrp
        ptilde0 = np.dot(pmat0, pmat0) + zetatrp * pmat0
    else:
        raise ValueError(
            ">>>>> mftpar['regtype'] must be 'PzetaE' or 'classic''")

    # decompose:
    if use_lud is True:
        LU0, P0 = scipy.linalg.lu_factor(ptilde0)
        #rhstmp = np.zeros([LU0.shape[1]])
        #xtmp = np.empty([LU0.shape[1]])
        #xtmp = scipy.linalg.lu_solve((LU0,P0),rhstmp)
        if verbosity >= 3:
            # Calculate condition number:
            #(sign, lndetbf) = np.linalg.slogdet(ptilde0)
            lndettr = np.sum(np.log(np.abs(np.diag(LU0))))
            #print "lndet(ptilde0) = %8.3f =?= %8.3f = sum(log(|diag(LU0)|))" % (lndetbf,lndettr)
            # log(prod(a_i, i=1,n)) for a_i = sqrt(sum(ptilde0_ij^2, j=1,n))
            denom = np.sum(np.log(np.sqrt(np.sum(ptilde0 * ptilde0, axis=0))))
            lncondno = lndettr - denom
            print "ln(condno) = %8.3f, K_H = 10^(%8.3f) = %8.3f" % (
                lncondno, lncondno / np.log(10.), np.exp(lncondno))
            print "(K_H < 0.01 : bad, K_H > 0.1 : good)"

    if use_svd is True:
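        # Truncated-SVD pseudo-inverse of ptilde0: singular values below
        # svrelcut * max(s) are zeroed before inversion.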
        U, s, V = np.linalg.svd(ptilde0, full_matrices=True)
        dtmp = s.max() * svrelcut
        s *= (abs(s) >= dtmp)
        sinv = [
            1. / s[k] if s[k] != 0. else 0. for k in xrange(ptilde0.shape[0])
        ]
        if verbosity >= 2:
            print ">>> With rel-cutoff=%e   %d out of %d SVs remain" % \
                  (svrelcut,np.array(np.nonzero(sinv)).shape[1],len(sinv))
        if verbosity >= 3:
            stat = np.allclose(ptilde0, np.dot(U, np.dot(np.diag(s), V)))
            print ">>> Testing svd-result: %s" % stat
            if not stat:
                print "    (Maybe due to SV-cutoff?)"
            print ">>> Setting ptildeinv=(U diag(sinv) V).tr"
        ptilde0inv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if verbosity >= 3:
            stat = np.allclose(np.identity(ptilde0.shape[0]),
                               np.dot(ptilde0inv, ptilde0))
            if stat:
                print ">>> Testing ptilde0inv-result (shld be unit-matrix): ok"
            else:
                print ">>> Testing ptilde0inv-result (shld be unit-matrix): failed"
                print np.transpose(np.dot(ptilde0inv, ptilde0))
                print ">>>"

    if verbosity >= 1:
        print "########## Create stc data and qual data arrays:"
    qualdata = {
        'relerr': np.zeros(data.shape[1]),
        'rdmerr': np.zeros(data.shape[1]),
        'mag': np.zeros(data.shape[1])
    }
    stcdata = np.zeros([n_loc / 3, data.shape[1]])
    stcdata1 = [np.zeros([s['nuse'], data.shape[1]]) for s in fwdmag['src']]
    stcinds = np.zeros((n_loc / 3, 2), dtype=int)
    stcinds1 = np.zeros((n_loc / 3), dtype=int)
    offsets = np.append([0], [s['nuse'] for s in fwdmag['src']])
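    # Build an index table mapping each flat source location (row of
    # fwdmag['source_rr']) to its (source-space block, local vertex) pair
    # by matching positions.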
    iblck = -1
    nmatch = 0
    for s in fwdmag['src']:
        iblck = iblck + 1
        for kvert0 in xrange(s['nuse']):
            kvert1 = offsets[iblck] + kvert0
            if np.all(
                    np.equal(fwdmag['source_rr'][kvert1],
                             s['rr'][s['vertno'][kvert0]])):
                stcinds[kvert1][0] = iblck
                stcinds[kvert1][1] = kvert0
                nmatch = nmatch + 1
    if verbosity >= 3:
        print "Found %d matches in creating source_rr/rr index table." % nmatch

    if verbosity >= 2:
        print "Reading slices of data to calc. cdv:"
        if data.shape[1] > 1000:
            print " "
    for islice in xrange(data.shape[1]):
        wdist = np.copy(wdist0)
        wdist3 = np.repeat(wdist, 3)
        pmat = np.copy(pmat0)
        ptilde = np.copy(ptilde0)
        if use_svd is True:
            ptildeinv = np.copy(ptilde0inv)
        if use_lud is True:
            LU = np.copy(LU0)
            P = np.copy(P0)

        slice = pscalefct * data[:, islice]
        if mftpar['regtype'] == 'PzetaE':
            mtilde = np.copy(slice)
        else:
            mtilde = np.dot(pmat, slice)

        acoeff = np.empty([ptilde.shape[0]])
        if use_svd is True:
            for irow in xrange(ptilde.shape[0]):
                acoeff[irow] = np.dot(ptildeinv[irow, :], mtilde)
        if use_lud is True:
            acoeff = scipy.linalg.lu_solve((LU, P), mtilde)

        cdv = np.zeros(n_loc)
        cdvnorms = np.zeros(n_loc / 3)
        for krow in xrange(lfmag.shape[0]):
            lfwtmp = lfmag[krow, :] * wdist3
            cdv += acoeff[krow] * lfwtmp

        tlw0 = time.time()
        tlc0 = time.clock()
        for mftiter in xrange(mftpar['iter']):
            # MFT iteration loop:

            cdvecs = np.reshape(cdv, (cdv.shape[0] / 3, 3))
            cdvnorms = np.sqrt(np.sum(cdvecs**2, axis=1))

            wdist = np.power(cdvnorms, mftpar['currexp']) * wdist0
            wdistsum = np.sum(wdist)
            wdist = wdist / wdistsum
            wdist3 = np.repeat(wdist, 3)

            # Calculate new P-matrix, incl. weights:
            tw0 = time.time()
            tc0 = time.clock()
            wdist3rt = np.repeat(np.sqrt(pscalefct * wdist), 3)
            lfw = lfmag * wdist3rt
            pmat = np.einsum('ik,jk->ij', lfw, lfw)

            tc1 = time.clock()
            tw1 = time.time()
            tptotwall += (tw1 - tw0)
            tptotcpu += (tc1 - tc0)
            nptotcall += 1

            # Regularize P:
            if mftpar['regtype'] == 'PzetaE':
                ptilde = pmat + zetatrp * np.identity(pmat.shape[0])
            else:
                ptilde = np.dot(pmat, pmat) + zetatrp * pmat

            # decompose:
            if use_svd is True:
                U, s, V = np.linalg.svd(ptilde, full_matrices=True)
                dtmp = s.max() * svrelcut
                s *= (abs(s) >= dtmp)
                sinv = [
                    1. / s[k] if s[k] != 0. else 0.
                    for k in xrange(ptilde.shape[0])
                ]
                ptildeinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
                for irow in xrange(ptilde.shape[0]):
                    acoeff[irow] = np.dot(ptildeinv[irow, :], mtilde)
            if use_lud is True:
                LU, P = scipy.linalg.lu_factor(ptilde)
                acoeff = scipy.linalg.lu_solve((LU, P), mtilde)

            cdv = np.einsum('ji,i,j->i', lfmag, wdist3, acoeff)

        tc1 = time.clock()
        tw1 = time.time()
        tltotwall += (tw1 - tlw0)
        tltotcpu += (tc1 - tlc0)
        nltotcall += 1
        cdvecs = np.reshape(cdv, (cdv.shape[0] / 3, 3))
        cdvnorms = np.sqrt(np.sum(cdvecs**2, axis=1))
        # (relerr,rdmerr,mag) = compare_est_exp(ptilde,acoeff,mtilde)
        (relerr, rdmerr, mag) = compare_est_exp(pmat, acoeff, slice)
        qualdata['relerr'][islice] = relerr
        qualdata['rdmerr'][islice] = rdmerr
        qualdata['mag'][islice] = mag

        tc0 = time.clock()
        tw0 = time.time()
        tc1 = time.clock()
        tw1 = time.time()
        tpcdmwall += (tw1 - tw0)
        tpcdmcpu += (tc1 - tc0)
        npcdmcall += 1

        # Write final cdv to file:
        for iloc in xrange(n_loc / 3):
            #stcdata1[stcinds[iloc][0]][stcinds[iloc][1],islice] = cdvnorms[iloc]
            stcdata[iloc, islice] = cdvnorms[iloc]
        del wdist
        if verbosity >= 2 and islice > 0 and islice % 1000 == 0:
            print "\r%6d out of %6d slices done." % (islice, data.shape[1])
    if verbosity >= 2 and data.shape[1] > 1000:
        print "Done."

    vertices = [s['vertno'] for s in fwdmag['src']]
    tstep = 1. / indathndl.info['sfreq']
    tmin = indathndl.times[0]
    if len(vertices) == 1:
        stc_mft = mne.VolSourceEstimate(stcdata,
                                        vertices=fwdmag['src'][0]['vertno'],
                                        tmin=tmin,
                                        tstep=tstep,
                                        subject=subject)
    else:
        stc_mft = mne.SourceEstimate(stcdata,
                                     vertices=vertices,
                                     tmin=tmin,
                                     tstep=tstep,
                                     subject=subject)

    stcdatamft = stc_mft.data
    print "##### Results:"
    for islice in xrange(data.shape[1]):
        print "slice=%4d: relerr=%9.3e rdmerr=%9.3e mag=%5.3f cdvmax=%9.2e" % \
              (islice, qualdata['relerr'][islice], qualdata['rdmerr'][islice], qualdata['mag'][islice],\
               np.amax(stcdatamft[:, islice]))
        #     (islice,qualdata[0,islice],qualdata[1,islice],qualdata[2,islice], \
        #     np.amax([np.amax(sb[:,islice]) for sb in stcdatamft]))
    stat = np.allclose(stcdata, stcdatamft, atol=0., rtol=1e-07)
    if stat:
        print "stcdata from mft-SR and old calc agree."
    else:
        print "stcdata from mft-SR and old calc DIFFER."

    if save_stc:
        # Save the resulting source estimate (surface or volume).
        print "##### Trying to save stc:"
        stcmft_fname = os.path.join(
            os.path.dirname(datafile),
            os.path.basename(datafile).split('-')[0]) + "mft"
        print "stcmft basefilename: %s" % stcmft_fname
        stc_mft.save(stcmft_fname, verbose=True)
        print "##### done."

    write_tab_files = True
    if write_tab_files:
        time_idx = np.argmax(np.max(stcdata, axis=0))
        tabfilenam = 'testtab.dat'
        print "##### Creating %s with |cdv(time_idx=%d)|" % (tabfilenam,
                                                             time_idx)
        tabfile = open(tabfilenam, mode='w')
        cdvnmax = np.max(stcdata[:, time_idx])
        tabfile.write("# time_idx = %d\n" % time_idx)
        tabfile.write("# max amplitude = %11.4e\n" % cdvnmax)
        tabfile.write("#  x/mm    y/mm    z/mm     |cdv|   index\n")
        for ipnt in xrange(n_loc / 3):
            copnt = 1000. * fwdmag['source_rr'][ipnt]
            tabfile.write(" %7.2f %7.2f %7.2f %11.4e %5d\n" % \
                          (copnt[0], copnt[1], copnt[2], stcdata[ipnt, time_idx], ipnt))
        tabfile.close()

        tabfilenam = 'testwtab.dat'
        print "##### Creating %s with wdist0" % tabfilenam
        tabfile = open(tabfilenam, mode='w')
        tabfile.write("# time_idx = %d\n" % time_idx)
        for icnt in xrange(prbcnt.shape[0]):
            cocnt = 1000. * prbcnt[icnt, :]
            tabfile.write("# center  %7.2f %7.2f %7.2f\n" %
                          (cocnt[0], cocnt[1], cocnt[2]))

        tabfile.write("# max value = %11.4e\n" % np.max(wdist0))
        tabfile.write("#  x/mm    y/mm    z/mm    wdist0   index")
        for icnt in xrange(prbcnt.shape[0]):
            tabfile.write("  d_%d/mm" % (icnt + 1))
        tabfile.write("\n")
        for ipnt in xrange(n_loc / 3):
            copnt = 1000. * fwdmag['source_rr'][ipnt]
            tabfile.write(" %7.2f %7.2f %7.2f %11.4e %5d" %\
                          (copnt[0],copnt[1],copnt[2],wdist0[ipnt],ipnt))
            for icnt in xrange(prbcnt.shape[0]):
                cocnt = 1000. * prbcnt[icnt, :]
                dist = np.sqrt(np.dot((copnt - cocnt), (copnt - cocnt)))
                tabfile.write("  %7.2f" % dist)
            tabfile.write("\n")
        tabfile.close()

    twgbl1 = time.time()
    tcgbl1 = time.clock()
    if verbosity >= 1:
        print "calc(lf*w*lf.T) took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            tptotcpu, tptotwall)
        print "calc(lf*w*lf.T) took per call %9.2fms CPU-time (%9.2fms walltime)" % \
                               (1000.*tptotcpu/float(nptotcall),1000.*tptotwall/float(nptotcall))
        print "iteration-loops took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            tltotcpu, tltotwall)
        print "iteration-loops took per call %9.2fms CPU-time (%9.2fms walltime)" % \
                               (1000.*tltotcpu/float(nltotcall),1000.*tltotwall/float(nltotcall))
        print "Total mft-call  took   total  %9.2f s CPU-time (%9.2f s walltime)" % (
            (tcgbl1 - tcgbl0), (twgbl1 - twgbl0))

    return (fwdmag, qualdata, stc_mft)
src = mne.setup_volume_source_space(
    pos=7.5,
    subject='spm',
    bem=bem,
    subjects_dir='/home/robbis/mne_data/MNE-spm-face/subjects/')

voldata = np.zeros((src[0]['nuse'], data.shape[0]))

vertno = src[0]['vertno']
vertpos = src[0]['rr'][vertno]

dist = cdist(vertpos, inside)
minvert = np.argmin(dist, axis=0)
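# minvert[k] indexes the source-space voxel closest to inside[k]; the data are
# written to those voxels (nearest-neighbour assignment).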
voldata[minvert] = data.T

stc = mne.VolSourceEstimate(voldata, [src[0]['vertno']], 0, 1, subject='spm')
brain = stc.plot(
    src,
    colormap='gnuplot',
    subjects_dir='/home/robbis/mne_data/MNE-spm-face/subjects/',
    mode='glass_brain',
    clim=dict(kind='percent', lims=[75, 85, 95]),
    initial_pos=np.array([-36, -25, 60]) / 1000.,
)

morph = mne.compute_source_morph(src,
                                 subject_from=None,
                                 subject_to='fsaverage',
                                 subjects_dir=None,
                                 zooms='auto',
                                 niter_affine=(100, 100, 10),