Beispiel #1
0
def test_convert_forward():
    """Round-trip a forward solution through its different representations.

    The solution goes cartesian -> surface orientation -> cartesian, then
    surface -> fixed orientation -> cartesian, and each round trip must
    reproduce the original. The checks against the ``surf_ori`` /
    ``force_fixed`` keywords of ``read_forward_solution`` are marked as
    removable in 0.16.
    """
    def _looks_valid(candidate):
        # A converted solution must keep a non-empty repr and stay a Forward.
        assert repr(candidate)
        assert isinstance(candidate, Forward)

    original = read_forward_solution(fname_meeg_grad)
    text = repr(original)
    assert '306' in text
    assert '60' in text
    assert text
    assert isinstance(original, Forward)

    # Switch to surface-based orientation.
    surf = convert_forward_solution(original, surf_ori=True)

    # Reader-keyword path (removable in 0.16) must agree with the conversion.
    surf_io = read_forward_solution(fname_meeg_grad, surf_ori=True)
    compare_forwards(surf, surf_io)
    del surf_io
    gc.collect()

    # Converting back must reproduce the original solution.
    restored = convert_forward_solution(surf, surf_ori=False)
    _looks_valid(restored)
    compare_forwards(original, restored)
    del restored
    gc.collect()

    # Fixed orientation, with use_cps disabled.
    fixed = convert_forward_solution(surf, surf_ori=True,
                                     force_fixed=True, use_cps=False)
    del surf
    gc.collect()
    _looks_valid(fixed)
    assert is_fixed_orient(fixed)

    # Reader-keyword path (removable in 0.16).
    fixed_io = read_forward_solution(fname_meeg_grad, force_fixed=True)
    _looks_valid(fixed_io)
    assert is_fixed_orient(fixed_io)
    compare_forwards(fixed, fixed_io)
    del fixed_io
    gc.collect()

    # Back to cartesian, free orientation (the starting condition).
    restored = convert_forward_solution(fixed, surf_ori=False,
                                        force_fixed=False)
    _looks_valid(restored)
    compare_forwards(original, restored)
    del original, restored, fixed
    gc.collect()
def test_convert_forward():
    """Test converting forward solution between different representations."""
    def _assert_forward(obj):
        # A converted solution must still be a Forward with a usable repr.
        assert repr(obj)
        assert isinstance(obj, Forward)

    fwd = read_forward_solution(fname_meeg_grad)
    summary = repr(fwd)
    for token in ('306', '60'):
        assert token in summary
    assert summary
    assert isinstance(fwd, Forward)

    # Surface orientation, then straight back to the original representation.
    surf = convert_forward_solution(fwd, surf_ori=True)
    round_trip = convert_forward_solution(surf, surf_ori=False)
    _assert_forward(round_trip)
    assert_forward_allclose(fwd, round_trip)
    del round_trip
    gc.collect()

    # Fixed orientation (use_cps disabled), then back to cartesian free
    # orientation, which must again match the original.
    fixed = convert_forward_solution(surf, surf_ori=True,
                                     force_fixed=True, use_cps=False)
    del surf
    gc.collect()
    _assert_forward(fixed)
    assert is_fixed_orient(fixed)
    round_trip = convert_forward_solution(fixed, surf_ori=False,
                                          force_fixed=False)
    _assert_forward(round_trip)
    assert_forward_allclose(fwd, round_trip)
    del fwd, round_trip, fixed
    gc.collect()
def test_priors():
    """Test depth-prior and orientation-prior computations."""
    # Depth prior
    fwd = read_forward_solution(fname_meeg)
    assert not is_fixed_orient(fwd)
    n_sources = fwd['nsource']
    info = read_info(fname_evoked)
    depth_prior = compute_depth_prior(fwd, info, exp=0.8)
    # free orientation: one weight per source component (3 per source)
    assert depth_prior.shape == (3 * n_sources,)
    # exp=0 disables depth weighting: all weights are 1
    depth_prior = compute_depth_prior(fwd, info, exp=0.)
    assert_array_equal(depth_prior, 1.)
    with pytest.raises(ValueError, match='must be "whiten"'):
        compute_depth_prior(fwd, info, limit_depth_chs='foo')
    with pytest.raises(ValueError, match='noise_cov must be a Covariance'):
        compute_depth_prior(fwd, info, limit_depth_chs='whiten')
    fwd_fixed = convert_forward_solution(fwd, force_fixed=True)
    depth_prior = compute_depth_prior(fwd_fixed, info=info)
    # fixed orientation: one weight per source
    assert depth_prior.shape == (n_sources,)
    # Orientation prior
    orient_prior = compute_orient_prior(fwd, 1.)
    assert_array_equal(orient_prior, 1.)
    orient_prior = compute_orient_prior(fwd_fixed, 0.)
    assert_array_equal(orient_prior, 1.)
    with pytest.raises(ValueError, match='oriented in surface coordinates'):
        compute_orient_prior(fwd, 0.5)
    fwd_surf_ori = convert_forward_solution(fwd, surf_ori=True)
    orient_prior = compute_orient_prior(fwd_surf_ori, 0.5)
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0
    assert np.isin(orient_prior, (0.5, 1.)).all()
    with pytest.raises(ValueError, match='between 0 and 1'):
        compute_orient_prior(fwd_surf_ori, -0.5)
    with pytest.raises(ValueError, match='with fixed orientation'):
        compute_orient_prior(fwd_fixed, 0.5)
Beispiel #4
0
def test_localization_bias_loose(bias_params_fixed, method, lower, upper,
                                 depth, loose, pick_ori):
    """Test inverse localization bias for loose minimum-norm solvers."""
    # eLORETA with vector output works, but is skipped to save cycles.
    if pick_ori == 'vector' and method == 'eLORETA':
        return
    evoked, fwd, noise_cov, _, want = bias_params_fixed
    fwd = convert_forward_solution(fwd, surf_ori=False, force_fixed=False)
    assert not is_fixed_orient(fwd)
    inv_loose = make_inverse_operator(evoked.info, fwd, noise_cov,
                                      loose=loose, depth=depth)
    estimate = apply_inverse(evoked, inv_loose, lambda2, method,
                             pick_ori=pick_ori)
    if pick_ori is None:
        amplitudes = estimate.data
    else:
        # Oriented output carries an explicit orientation axis.
        assert estimate.data.ndim == 3
        estimate, directions = estimate.project('pca', src=fwd['src'])
        # Most PCA directions should align with the [2::3] rows of
        # source_nn (presumably the per-source normal -- confirm).
        reference = inv_loose['source_nn'][2::3]
        abs_cos_sim = np.abs((directions * reference).sum(axis=1))
        assert np.percentile(abs_cos_sim, 10) > 0.9  # most very aligned
        amplitudes = abs(estimate).data
    assert (amplitudes >= 0).all()
    # Percentage of sources whose peak index matches `want` (no loc bias).
    peak_idx = np.argmax(amplitudes, axis=0)
    perc = (want == peak_idx).mean() * 100
    assert lower <= perc <= upper, method
Beispiel #5
0
def test_priors():
    """Test depth-prior and orientation-prior computations."""
    # Depth prior
    fwd = read_forward_solution(fname_meeg)
    assert not is_fixed_orient(fwd)
    n_sources = fwd['nsource']
    info = read_info(fname_evoked)
    depth_prior = compute_depth_prior(fwd, info, exp=0.8)
    # free orientation: one weight per source component (3 per source)
    assert depth_prior.shape == (3 * n_sources,)
    # exp=0 disables depth weighting: all weights are 1
    depth_prior = compute_depth_prior(fwd, info, exp=0.)
    assert_array_equal(depth_prior, 1.)
    with pytest.raises(ValueError, match='must be "whiten"'):
        compute_depth_prior(fwd, info, limit_depth_chs='foo')
    with pytest.raises(ValueError, match='noise_cov must be a Covariance'):
        compute_depth_prior(fwd, info, limit_depth_chs='whiten')
    fwd_fixed = convert_forward_solution(fwd, force_fixed=True)
    # Passing a raw gain matrix with is_fixed_ori is the deprecated call form.
    with pytest.deprecated_call():
        depth_prior = compute_depth_prior(
            fwd_fixed['sol']['data'], info, is_fixed_ori=True)
    # fixed orientation: one weight per source
    assert depth_prior.shape == (n_sources,)
    # Orientation prior
    orient_prior = compute_orient_prior(fwd, 1.)
    assert_array_equal(orient_prior, 1.)
    orient_prior = compute_orient_prior(fwd_fixed, 0.)
    assert_array_equal(orient_prior, 1.)
    with pytest.raises(ValueError, match='oriented in surface coordinates'):
        compute_orient_prior(fwd, 0.5)
    fwd_surf_ori = convert_forward_solution(fwd, surf_ori=True)
    orient_prior = compute_orient_prior(fwd_surf_ori, 0.5)
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0
    assert np.isin(orient_prior, (0.5, 1.)).all()
    with pytest.raises(ValueError, match='between 0 and 1'):
        compute_orient_prior(fwd_surf_ori, -0.5)
    with pytest.raises(ValueError, match='with fixed orientation'):
        compute_orient_prior(fwd_fixed, 0.5)
Beispiel #6
0
def test_localization_bias_loose(bias_params_fixed, method, lower, upper,
                                 depth):
    """Test inverse localization bias for loose minimum-norm solvers."""
    evoked, fwd, noise_cov, _, want = bias_params_fixed
    # Use a cartesian, free-orientation version of the forward model.
    free_fwd = convert_forward_solution(fwd, surf_ori=False,
                                        force_fixed=False)
    assert not is_fixed_orient(free_fwd)
    inverse = make_inverse_operator(evoked.info, free_fwd, noise_cov,
                                    loose=0.2, depth=depth)
    estimate = apply_inverse(evoked, inverse, lambda2, method)
    amplitudes = estimate.data
    assert (amplitudes >= 0).all()
    # Percentage of sources whose peak index matches `want` (no loc bias).
    peak_idx = np.argmax(amplitudes, axis=0)
    perc = (want == peak_idx).mean() * 100
    assert lower <= perc <= upper, method
Beispiel #7
0
def test_localization_bias_loose(bias_params_fixed, method, lower, upper,
                                 depth):
    """Test inverse localization bias for loose minimum-norm solvers."""
    evoked, fwd, noise_cov, _, want = bias_params_fixed
    # Use a cartesian, free-orientation version of the forward model.
    free_fwd = convert_forward_solution(fwd, surf_ori=False,
                                        force_fixed=False)
    assert not is_fixed_orient(free_fwd)
    inverse = make_inverse_operator(evoked.info, free_fwd, noise_cov,
                                    loose=0.2, depth=depth)
    estimate = apply_inverse(evoked, inverse, lambda2, method)
    amplitudes = estimate.data
    assert (amplitudes >= 0).all()
    # Percentage of sources whose peak index matches `want` (no loc bias).
    peak_idx = np.argmax(amplitudes, axis=0)
    perc = (want == peak_idx).mean() * 100
    assert lower <= perc <= upper, method
def test_io_forward():
    """Test IO for forward solutions.

    Solutions are written, read back, re-converted and compared, both in
    surface orientation and in fixed orientation (with and without
    ``use_cps``). The test also checks the naming warnings for a bad
    file name.
    """
    temp_dir = _TempDir()
    # do extensive tests with MEEG + grad
    n_channels, n_src = 366, 108
    fname_temp = op.join(temp_dir, 'test-fwd.fif')

    def _write(fwd):
        # Writing may emit warnings here; record them so they stay local.
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('always')
            write_forward_solution(fname_temp, fwd, overwrite=True)

    def _check_fixed_fields(fwd, n_cols):
        # Shape and metadata checks on a fixed, surface-oriented solution.
        leadfield = fwd['sol']['data']
        assert_equal(leadfield.shape, (n_channels, n_cols))
        assert_equal(len(fwd['sol']['row_names']), n_channels)
        assert_equal(len(fwd['info']['chs']), n_channels)
        assert_true('dev_head_t' in fwd['info'])
        assert_true('mri_head_t' in fwd)
        assert_true(fwd['surf_ori'])

    def _check_fixed_roundtrip(fwd, use_cps):
        # Write, read back, re-apply the same fixed conversion, and compare.
        _write(fwd)
        fwd_read = read_forward_solution(fname_temp)
        fwd_read = convert_forward_solution(fwd_read, surf_ori=True,
                                            force_fixed=True,
                                            use_cps=use_cps)
        assert_true(repr(fwd_read))
        assert_true(isinstance(fwd_read, Forward))
        assert_true(is_fixed_orient(fwd_read))
        compare_forwards(fwd, fwd_read)

    fwd = read_forward_solution(fname_meeg_grad)
    assert_true(isinstance(fwd, Forward))
    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    leadfield = fwd['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd['sol']['row_names']), n_channels)
    _write(fwd)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True)
    leadfield = fwd_read['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd_read['sol']['row_names']), n_channels)
    assert_equal(len(fwd_read['info']['chs']), n_channels)
    assert_true('dev_head_t' in fwd_read['info'])
    assert_true('mri_head_t' in fwd_read)
    assert_array_almost_equal(fwd['sol']['data'], fwd_read['sol']['data'])

    # Fixed orientation without cps.
    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=False)
    _check_fixed_roundtrip(fwd, use_cps=False)

    # Same solution with cps; integer division keeps the expected column
    # count an int (1494 / 3 would be a float under Python 3).
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    _check_fixed_fields(fwd, 1494 // 3)
    _check_fixed_roundtrip(fwd, use_cps=True)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    _check_fixed_fields(fwd, n_src // 3)
    _check_fixed_roundtrip(fwd, use_cps=True)

    # test warnings on bad filenames
    fwd = read_forward_solution(fname_meeg_grad)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        fwd_badname = op.join(temp_dir, 'test-bad-name.fif.gz')
        write_forward_solution(fwd_badname, fwd)
        read_forward_solution(fwd_badname)
    assert_naming(w, 'test_forward.py', 2)

    fwd = read_forward_solution(fname_meeg)
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    compare_forwards(fwd, fwd_read)
def test_io_forward(tmpdir):
    """Test IO for forward solutions.

    Solutions are written, read back, re-converted and compared, both in
    surface orientation and in fixed orientation (with and without
    ``use_cps``). The test also checks the warnings for a bad file name.
    """
    # do extensive tests with MEEG + grad
    n_channels, n_src = 366, 108
    fname_temp = tmpdir.join('test-fwd.fif')

    def _write(fwd):
        # Writing these converted solutions is expected to warn.
        with pytest.warns(RuntimeWarning, match='stored on disk'):
            write_forward_solution(fname_temp, fwd, overwrite=True)

    def _check_fixed_fields(fwd, n_cols):
        # Shape and metadata checks on a fixed, surface-oriented solution.
        leadfield = fwd['sol']['data']
        assert_equal(leadfield.shape, (n_channels, n_cols))
        assert_equal(len(fwd['sol']['row_names']), n_channels)
        assert_equal(len(fwd['info']['chs']), n_channels)
        assert 'dev_head_t' in fwd['info']
        assert 'mri_head_t' in fwd
        assert fwd['surf_ori']

    def _check_fixed_roundtrip(fwd, use_cps):
        # Write, read back, re-apply the same fixed conversion, and compare.
        _write(fwd)
        fwd_read = read_forward_solution(fname_temp)
        fwd_read = convert_forward_solution(fwd_read, surf_ori=True,
                                            force_fixed=True,
                                            use_cps=use_cps)
        assert repr(fwd_read)
        assert isinstance(fwd_read, Forward)
        assert is_fixed_orient(fwd_read)
        assert_forward_allclose(fwd, fwd_read)

    fwd = read_forward_solution(fname_meeg_grad)
    assert isinstance(fwd, Forward)
    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    leadfield = fwd['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd['sol']['row_names']), n_channels)
    _write(fwd)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True)
    leadfield = fwd_read['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd_read['sol']['row_names']), n_channels)
    assert_equal(len(fwd_read['info']['chs']), n_channels)
    assert 'dev_head_t' in fwd_read['info']
    assert 'mri_head_t' in fwd_read
    assert_array_almost_equal(fwd['sol']['data'], fwd_read['sol']['data'])

    # Fixed orientation without cps.
    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=False)
    _check_fixed_roundtrip(fwd, use_cps=False)

    # Same solution with cps; integer division keeps the expected column
    # count an int (1494 / 3 is a float under Python 3).
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    _check_fixed_fields(fwd, 1494 // 3)
    _check_fixed_roundtrip(fwd, use_cps=True)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    _check_fixed_fields(fwd, n_src // 3)
    _check_fixed_roundtrip(fwd, use_cps=True)

    # test warnings on bad filenames
    fwd = read_forward_solution(fname_meeg_grad)
    fwd_badname = tmpdir.join('test-bad-name.fif.gz')
    with pytest.warns(RuntimeWarning, match='end with'):
        write_forward_solution(fwd_badname, fwd)
    with pytest.warns(RuntimeWarning, match='end with'):
        read_forward_solution(fwd_badname)

    fwd = read_forward_solution(fname_meeg)
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    assert_forward_allclose(fwd, fwd_read)
Beispiel #10
0
    ind1 = fwd['src'][1]['inuse']
    # positions of dipoles
    rr = np.vstack(
        [fwd['src'][0]['rr'][ind0 == 1, :], fwd['src'][1]['rr'][ind1 == 1, :]])
    rr = rr / np.max(np.sum(rr**2, axis=1))
    nn = np.vstack(
        [fwd['src'][0]['nn'][ind0 == 1, :], fwd['src'][1]['nn'][ind1 == 1, :]])
    # number of dipoles
    #m = rr.shape[0]
    all_ch_names = evoked.ch_names
    sel = [
        l for l in range(len(all_ch_names))
        if all_ch_names[l] not in evoked.info['bads']
    ]
    print "fixed orient"
    print is_fixed_orient(fwd)
    if not is_fixed_orient(fwd):
        _to_fixed_ori(fwd)
    print "difference of G"
    print np.max(np.abs(fwd['sol']['data'])) / np.min(
        np.abs(fwd['sol']['data']))

    if whiten_flag:
        pca = True
        G, G_info, whitener, source_weighting, mask = _prepare_gain(
            fwd,
            evoked.info,
            noise_cov,
            pca=pca,
            depth=depth,
            loose=None,
Beispiel #11
0
def get_estimate(filepath,
                 outname,
                 method="ROIcov",
                 loose=None,
                 depth=0.8,
                 verbose=True,
                 whiten_flag=True):
    """
    filepath, full path of the file, with no ".mat" or "-ave.fif" suffix.
    e.g. filepath = "/home/ying/sandbox/MEG_simu_test"
    """

    #========= load data ======================
    mat_dict = scipy.io.loadmat(filepath + ".mat")
    ROI_list = list()
    n_ROI = len(mat_dict['ROI_list'][0])
    n_ROI_valid = n_ROI - 1

    for i in range(n_ROI):
        ROI_list.append(mat_dict['ROI_list'][0, i][0])

    M = mat_dict['M']
    QUcov = mat_dict['QUcov']
    Tcov = mat_dict['Tcov']
    Sigma_J_list = mat_dict['Sigma_J_list'][0]
    L_list = list()
    for i in range(n_ROI_valid):
        L_list.append(mat_dict['L_list'][0, i][0])

    fwd_path = mat_dict['fwd_path'][0]
    noise_cov_path = mat_dict['noise_cov_path'][0]
    Sigma_E = mat_dict['Sigma_E']
    G = mat_dict['G']

    # this function returns a list, take the first element
    evoked = mne.read_evokeds(filepath + "-ave.fif.gz")[0]
    # depth weighting has no effect, because this is force_fixed
    fwd = mne.read_forward_solution(fwd_path, force_fixed=True, surf_ori=True)
    noise_cov = mne.read_cov(noise_cov_path)

    # orientation of dipoles
    ind0 = fwd['src'][0]['inuse']
    ind1 = fwd['src'][1]['inuse']
    # positions of dipoles
    rr = np.vstack(
        [fwd['src'][0]['rr'][ind0 == 1, :], fwd['src'][1]['rr'][ind1 == 1, :]])
    rr = rr / np.max(np.sum(rr**2, axis=1))
    nn = np.vstack(
        [fwd['src'][0]['nn'][ind0 == 1, :], fwd['src'][1]['nn'][ind1 == 1, :]])

    # for some methods, only analyze a snapshot
    T = M.shape[-1]
    q = M.shape[0]
    t_ind = T // 2 + 1
    m = rr.shape[0]

    all_ch_names = evoked.ch_names
    sel = [
        l for l in range(len(all_ch_names))
        if all_ch_names[l] not in evoked.info['bads']
    ]

    #======== ROI cov method
    if method is "ROIcov" or "ROIcovKronecker":
        if loose is None and not is_fixed_orient(fwd):
            # follow the tf_mixed_norm
            fwd = copy.deepcopy(fwd)
            # it seems that this will result in different G from reading the forward with force_fixed = True
            _to_fixed_ori(fwd)

        if whiten_flag:
            Sigma_E_chol = np.linalg.cholesky(Sigma_E)
            Sigma_E_chol_inv = np.linalg.inv(Sigma_E_chol)
            G = np.dot(Sigma_E_chol_inv, G)
            # after whitening, the noise cov is assumed to identity
            #Sigma_E = (np.dot(Sigma_E_chol_inv, Sigma_E)).dot(Sigma_E_chol_inv.T)
            Sigma_E = np.eye(G.shape[0])
            M = (np.dot(Sigma_E_chol_inv, M)).transpose([1, 0, 2])

        # these arguments are going to be passed from inputs
        prior_Q, prior_Sigma_J, prior_L, prior_Tcov = False, False, True, False
        Q_flag, Sigma_J_flag, L_flag, Tcov_flag = True, True, True, True
        tau, step_ini, MaxIter, tol, MaxIter0, tol0, verbose0 = 0.8, 1.0, 10, 1E-5, 10, 1E-3, False
        # Create prior for L, not necessarily the same as the truth
        L_list_param = 1.5  # a exp (-b ||x-y||^2)
        Q_L_list = list()
        for i in range(n_ROI_valid):
            tmp_n = len(ROI_list[i])
            tmp = np.zeros([tmp_n, tmp_n])
            for i0 in range(tmp_n):
                for i1 in range(tmp_n):
                    tmp[i0, i1] = np.dot(nn[i0, :], nn[i1, :]) * np.exp(
                        -L_list_param * (np.sum((rr[i0, :] - rr[i1, :])**2)))
            #print np.linalg.cond(tmp)
            Q_L_list.append(tmp)
        inv_Q_L_list = copy.deepcopy(Q_L_list)
        for i in range(n_ROI_valid):
            inv_Q_L_list[i] = np.linalg.inv(Q_L_list[i])

        # other priors, not used for now
        alpha, beta = 1.0, 1.0
        nu = n_ROI_valid + 1
        V_inv = np.eye(n_ROI_valid) * 1E-4
        eps = 1E-13

        V = np.eye(n_ROI_valid)
        nu1 = T + 1
        V1 = np.eye(T)

        # how many times to randomly initialize the data
        n_ini = 1
        result_all = np.zeros(n_ini, dtype=np.object)

        # Tcov initialization to be added
        Tcov0, _ = get_mle_kron_cov(M, tol=1E-6, MaxIter=100)
        T00 = np.linalg.cholesky(Tcov0)

        for l in range(n_ini):
            Qu0 = np.eye(n_ROI_valid) * 1E-18
            sigma_J_list0 = np.ones(n_ROI) * 1E-18
            Sigma_J_list0 = sigma_J_list0**2
            L_list0 = copy.deepcopy(L_list)
            for i in range(n_ROI_valid):
                #L_list0[i] = np.random.randn(L_list0[i].size)
                L_list0[i] = np.ones(L_list0[i].size)

            if method is "ROIcov":
                # just analyze the middle time point.
                MMT = M[:, :, t_ind].T.dot(M[:, :, t_ind])
                if verbose:
                    print "initial obj"
                    Phi0 = np.linalg.cholesky(Qu0)
                    sigma_J_list0 = np.sqrt(Sigma_J_list0)
                    obj0 = get_neg_llh(Phi0, sigma_J_list0, L_list0, ROI_list,
                                       G, MMT, q, Sigma_E, nu, V_inv,
                                       inv_Q_L_list, alpha, beta, prior_Q,
                                       prior_Sigma_J, prior_L, eps)
                    print obj0
                    print "optimial obj"
                    Phi = np.linalg.cholesky(QUcov)
                    # lower case indicates the square root
                    sigma_J_list = np.sqrt(Sigma_J_list)
                    obj_star = get_neg_llh(Phi, sigma_J_list, L_list, ROI_list,
                                           G, MMT, q, Sigma_E, nu, V_inv,
                                           inv_Q_L_list, alpha, beta, prior_Q,
                                           prior_Sigma_J, prior_L, eps)
                    print obj_star
                Qu_hat, Sigma_J_list_hat, L_list_hat, obj = get_map_coor_descent(
                    Qu0,
                    Sigma_J_list0,
                    L_list0,
                    ROI_list,
                    G,
                    MMT,
                    q,
                    Sigma_E,
                    nu,
                    V_inv,
                    inv_Q_L_list,
                    alpha,
                    beta,
                    prior_Q,
                    prior_Sigma_J,
                    prior_L,
                    Q_flag=Q_flag,
                    Sigma_J_flag=Sigma_J_flag,
                    L_flag=L_flag,
                    tau=tau,
                    step_ini=step_ini,
                    MaxIter=MaxIter,
                    tol=tol,
                    eps=eps,
                    verbose=verbose,
                    verbose0=verbose0,
                    MaxIter0=MaxIter0,
                    tol0=tol0)

                diag0 = np.sqrt(np.diag(Qu_hat))
                denom = np.outer(diag0, diag0)
                corr_hat = np.abs(Qu_hat / denom)

                result_all[l] = dict(obj=obj,
                                     Qu_hat=Qu_hat,
                                     Sigma_J_list_hat=Sigma_J_list_hat,
                                     L_list_hat=L_list_hat,
                                     corr_hat=corr_hat,
                                     Tcov_hat=0.0)
            elif method is "ROIcovKronecker":
                if verbose:
                    print "initial obj"
                    obj0 = get_neg_llh_kron(
                        Phi0,
                        sigma_J_list0,
                        L_list0,
                        T00,  # unknown parameters
                        ROI_list,
                        G,
                        M,
                        q,
                        nu,
                        V,
                        nu1,
                        V1,
                        inv_Q_L_list,
                        alpha,
                        beta,  # prior params
                        prior_Q,
                        prior_Sigma_J,
                        prior_L,
                        prior_Tcov)  # prior flags
                    print obj0
                    print "optimial obj"
                    Phi = np.linalg.cholesky(QUcov)
                    sigma_J_list = np.sqrt(Sigma_J_list)
                    T0 = np.linalg.cholesky(Tcov)
                    obj_star = get_neg_llh_kron(
                        Phi,
                        sigma_J_list,
                        L_list,
                        T0,  # unknown parameters
                        ROI_list,
                        G,
                        M,
                        q,
                        nu,
                        V,
                        nu1,
                        V1,
                        inv_Q_L_list,
                        alpha,
                        beta,  # prior params
                        prior_Q,
                        prior_Sigma_J,
                        prior_L,
                        prior_Tcov)  # prior flags
                    print obj_star
                Qu_hat, Sigma_J_list_hat, L_list_hat, Tcov_hat, obj = get_map_coor_descent_kron(
                    Qu0,
                    Sigma_J_list0,
                    L_list0,
                    Tcov0,  # unknown parameters
                    ROI_list,
                    G,
                    M,
                    q,
                    nu,
                    V,
                    nu1,
                    V1,
                    inv_Q_L_list,
                    alpha,
                    beta,  # prior params
                    prior_Q,
                    prior_Sigma_J,
                    prior_L,
                    prior_Tcov,  # prior flags
                    Q_flag=Q_flag,
                    Sigma_J_flag=Sigma_J_flag,
                    L_flag=L_flag,
                    Tcov_flag=Tcov_flag,
                    tau=tau,
                    step_ini=step_ini,
                    MaxIter=MaxIter,
                    tol=tol,
                    verbose=verbose,  # optimization params
                    MaxIter0=MaxIter0,
                    tol0=tol0,
                    verbose0=verbose0)
                diag0 = np.sqrt(np.diag(Qu_hat))
                denom = np.outer(diag0, diag0)
                corr_hat = np.abs(Qu_hat / denom)
                result_all[l] = dict(obj=obj,
                                     Qu_hat=Qu_hat,
                                     Tcov_hat=Tcov_hat,
                                     Sigma_J_list_hat=Sigma_J_list_hat,
                                     L_list_hat=L_list_hat,
                                     corr_hat=corr_hat,
                                     method=method,
                                     lambda2=0.0)
        # choose the best results
        obj_list = np.zeros(n_ini)
        for l in range(n_ini):
            obj_list[l] = result_all[l]['obj']
        result = result_all[np.argmin(obj_list)]

    # can do dSPM too,  directly apply the kernel
    elif method in [
            "mneFlip", "mneTrueL", "mnePairwise", "mneTrueLKronecker",
            "mneFlipKronecker"
    ]:
        mne_method = "dSPM"  # can be MNE
        lambda2 = 1.0

        q = M.shape[0]
        # create the inverse operator
        inv_op = mne.minimum_norm.make_inverse_operator(evoked.info,
                                                        fwd,
                                                        noise_cov,
                                                        loose=loose,
                                                        depth=depth,
                                                        fixed=True)
        # apply the inverse
        # create an epoch object with the data
        # the events parameter here is fake.
        M_aug = np.zeros([q, evoked.info['nchan'], T])
        M_aug[:, sel, :] = M.copy()
        epochs = mne.EpochsArray(data=M_aug,
                                 info=evoked.info,
                                 events=np.ones([q, 3], dtype=np.int),
                                 tmin=evoked.times[0],
                                 event_id=None,
                                 reject=None)
        source_sol = mne.minimum_norm.apply_inverse_epochs(epochs,
                                                           inv_op,
                                                           lambda2=lambda2,
                                                           method=mne_method,
                                                           nave=1)
        m, T = source_sol[0].data.shape
        J_two_step = np.zeros([q, m, T])
        for r in range(q):
            J_two_step[r] = source_sol[r].data

        if method in [
                "mneFlip", "mneTrueL", "mneTrueLKronecker", "mneFlipKronecker"
        ]:
            U_two_step = np.zeros([q, n_ROI_valid, T])
            for i in range(n_ROI_valid):
                J_tmp = J_two_step[:, ROI_list[i], :]
                if method in ["mneFlip", "mneFlipKronecker"]:
                    tmp_nn = nn[ROI_list[i], :]
                    tmpu, tmpd, tmpv = np.linalg.svd(tmp_nn)
                    tmp_sign = np.sign(np.dot(tmp_nn, tmpv[0]))
                    U_two_step[:, i, :] = np.mean(J_tmp.transpose([0, 2, 1]) *
                                                  tmp_sign,
                                                  axis=-1)
                elif method in ["mneTrueL", "mneTrueLKronecker"]:
                    # U = (L^T L)^{-1} L^T J
                    tmp_true_L = L_list[i]
                    tmp_LTL = np.dot(tmp_true_L, tmp_true_L)
                    U_two_step[:, i, :] = (
                        np.dot(J_tmp.transpose([0, 2, 1]), tmp_true_L) /
                        tmp_LTL)
            if method in ["mneFlip", "mneTrueL"]:
                U_two_step0 = U_two_step[:, :, t_ind]
                Qu_hat = np.cov(U_two_step0.T)
                Tcov_hat = 0.0
            else:
                # kronecker
                Tcov_hat, Qu_hat = get_mle_kron_cov(U_two_step,
                                                    tol=1E-6,
                                                    MaxIter=100)

            diag0 = np.sqrt(np.diag(Qu_hat))
            denom = np.outer(diag0, diag0)
            corr_hat = np.abs(Qu_hat / denom)
        elif method in ["mnePairwise"]:
            Qu_hat, Tcov_hat = 0.0, 0.0
            corr_hat = np.eye(n_ROI_valid)
            for l1 in range(n_ROI_valid):
                for l2 in range(l1 + 1, n_ROI_valid):
                    J_tmp1 = J_two_step[:, ROI_list[l1], t_ind].T
                    J_tmp2 = J_two_step[:, ROI_list[l2], t_ind].T
                    tmp_corr = np.corrcoef(np.vstack([J_tmp1, J_tmp2]))
                    tmp_corr_valid = tmp_corr[0:J_tmp1.shape[0],
                                              J_tmp1.shape[0]::]
                    corr_hat[l1, l2] = np.mean(np.abs(tmp_corr_valid))
                    corr_hat[l2, l1] = corr_hat[l1, l2]
        result = dict(obj=0.0,
                      Qu_hat=Qu_hat,
                      Tcov_hat=Tcov_hat,
                      Sigma_J_list_hat=0.0,
                      L_list=0.0,
                      corr_hat=corr_hat,
                      method=method,
                      lambda2=lambda2)
    # save the result
    scipy.io.savemat(outname, result)
def save_visualized_jt(M,
                       noise_cov_path,
                       evoked_path,
                       Sigma_J_list,
                       ut,
                       ROI_list,
                       n_ROI_valid,
                       subjects_dir,
                       subj,
                       fwd_path,
                       out_stc_name,
                       out_fig_name,
                       whiten_flag,
                       depth=None,
                       force_fixed=True,
                       tmin=0,
                       tstep=0.01,
                       simupath=None):
    """Reconstruct source currents J given ROI signals, plot and save them.

    For every trial r the source estimate is
        J[r] = QJ G' Q0^{-1} M[r] + (I - QJ G' Q0^{-1} G) L ut[r],
    with Q0 = Sigma_E + G QJ G', QJ the per-dipole prior variance built from
    ``Sigma_J_list`` and L the 0/1 map from valid ROIs to their dipoles.
    The result is compared in figures with an MNE solution and (optionally)
    with the simulated truth, and the across-trial standard deviation of J
    is saved as an STC.

    Parameters
    ----------
    M : array, shape (q, n_channels, T)
        Sensor data per trial. The channels are assumed to be the non-bad
        channels of the evoked template, in order.
    noise_cov_path, evoked_path, fwd_path : str
        Paths of the noise covariance, the evoked template and the forward
        solution.
    Sigma_J_list : sequence, length len(ROI_list)
        Prior variance assigned to all dipoles of each ROI.
    ut : array, shape (q, n_ROI_valid, T)
        ROI time courses per trial (e.g. can be a single time point).
    ROI_list : list of index arrays
        Dipole indices of each ROI; the first ``n_ROI_valid`` are modelled
        by ``ut``.
    n_ROI_valid : int
        Number of ROIs with a latent time course.
    subjects_dir, subj, out_fig_name
        Currently unused (kept for interface compatibility).
    out_stc_name : str
        Output path for the STC of the across-trial std of J.
    whiten_flag : bool
        If True, whiten gain and data with MNE's ``_prepare_gain``.
    depth : float | None
        Depth-weighting exponent (None means 0).
    force_fixed : bool
        Use fixed-orientation dipoles.
    tmin, tstep : float
        Time of the first sample and sampling step, in seconds.
    simupath : str | None
        ``.mat`` file with the simulated truth (keys ``'J'`` and ``'u'``).
        If None, a module-level ``simupath`` is used when defined (older
        callers relied on that global); otherwise the truth figure is
        skipped.

    Returns
    -------
    int
        Always 0.
    """
    if depth is None:
        depth = 0.0
    # ROI_list holds every ROI to plot (valid ones first).
    n_ROI = len(ROI_list)

    q, _, T = M.shape
    # Keep the unwhitened sensor data: the MNE inverse below does its own
    # whitening and must receive the data in sensor units.
    M_sensor = M

    # this function returns a list, take the first element
    evoked = mne.read_evokeds(evoked_path)[0]
    print(force_fixed)
    fwd0 = mne.read_forward_solution(fwd_path,
                                     force_fixed=force_fixed,
                                     surf_ori=True)
    fwd = copy.deepcopy(fwd0)
    noise_cov = mne.read_cov(noise_cov_path)
    Sigma_E = noise_cov.data

    # number of dipoles in use (left hemisphere followed by right hemisphere)
    ind0 = fwd['src'][0]['inuse']
    ind1 = fwd['src'][1]['inuse']
    m = int(np.sum(ind0 == 1) + np.sum(ind1 == 1))

    bads = evoked.info['bads']
    sel = [i for i, name in enumerate(evoked.ch_names) if name not in bads]

    if force_fixed:
        print("fixed orient")
        print(is_fixed_orient(fwd))
        if not is_fixed_orient(fwd):
            _to_fixed_ori(fwd)
        print("difference of G")
        gain_abs = np.abs(fwd['sol']['data'])
        print(np.max(gain_abs) / np.min(gain_abs))

    if whiten_flag:
        G, _, whitener, _, _ = _prepare_gain(fwd,
                                             evoked.info,
                                             noise_cov,
                                             pca=True,
                                             depth=depth,
                                             loose=None,
                                             weights=None,
                                             weights_min=None)
        # after whitening, the sensor noise covariance is the identity
        Sigma_E = np.eye(G.shape[0])
        M = np.dot(whitener, M).transpose([1, 0, 2])
    else:
        # NOTE(review): Sigma_E keeps every channel of the noise covariance
        # while G keeps only the non-bad channels; confirm they match.
        G = fwd['sol']['data'][sel, :]
        # 2.0 avoids integer division under Python 2 for an integer depth
        G = G / np.sum(G ** 2, axis=0) ** (depth / 2.0)

    # prior variance of each dipole: all dipoles of ROI l share Sigma_J_list[l]
    QJ = np.zeros(m)
    for l, roi in enumerate(ROI_list):
        QJ[roi] = Sigma_J_list[l]

    # L maps the latent ROI time courses onto the dipoles of the valid ROIs
    L = np.zeros([m, n_ROI_valid])
    for l in range(n_ROI_valid):
        L[ROI_list[l], l] = 1.0

    GQJ = G * QJ
    invQ0 = np.linalg.inv(Sigma_E + GQJ.dot(G.T))
    # QJ G' (QE + G QJ G')^{-1}
    operator_y = GQJ.T.dot(invQ0)
    # (I - QJ G' (QE + G QJ G')^{-1} G) L
    operator_u = (np.eye(m) - operator_y.dot(G)).dot(L)

    J = np.zeros([q, m, T])
    for r in range(q):
        J[r] = operator_y.dot(M[r]) + operator_u.dot(ut[r])

    # ROI signals mapped onto the dipoles, shape (q, m, T)
    LU = np.dot(L, ut).transpose([1, 0, 2])
    # quick diagnostic: J against L u for the first trial and time point
    plt.plot(J[0, :, 0].ravel(), LU[0, :, 0].ravel(), '.')

    # ---- MNE solution, computed from the unwhitened sensor data ----
    # Re-read from disk so the MNE inverse starts from freshly loaded objects.
    evoked = mne.read_evokeds(evoked_path)[0]
    noise_cov = mne.read_cov(noise_cov_path)
    ch_names = evoked.info['ch_names']
    valid_channel_ind = [i for i, name in enumerate(ch_names)
                         if name not in evoked.info['bads']]
    M_all = np.zeros([q, len(ch_names), T])
    M_all[:, valid_channel_ind, :] = M_sensor
    # one event per trial with distinct sample indices (recent MNE versions
    # reject repeated event samples); event id 1 for every trial
    events = np.zeros([q, 3], dtype=int)
    events[:, 0] = np.arange(q)
    events[:, 2] = 1
    epochs = mne.EpochsArray(data=M_all,
                             info=evoked.info,
                             events=events,
                             tmin=evoked.times[0],
                             event_id=None,
                             reject=None)
    depth0 = None if depth == 0 else depth
    inv_op = mne.minimum_norm.make_inverse_operator(evoked.info,
                                                    fwd,
                                                    noise_cov,
                                                    loose=0.0,
                                                    depth=depth0,
                                                    fixed=True)
    stcs = mne.minimum_norm.apply_inverse_epochs(epochs,
                                                 inv_op,
                                                 lambda2=1.0,
                                                 method="MNE")
    J_mne = np.zeros([q, m, T])
    for r in range(q):
        J_mne[r] = stcs[r].data

    # across-trial standard deviations
    J_std = np.std(J, axis=0)
    u_std = np.std(ut, axis=0)
    J_mne_std = np.std(J_mne, axis=0)

    # figures: one trial, its MNE counterpart, the two stds, and the truth
    trial_ind = 0
    panels = [(J[trial_ind], ut[trial_ind]),
              (J_mne[trial_ind], ut[trial_ind]),
              (J_std, u_std),
              (J_mne_std, u_std)]
    if simupath is None:
        # backward compatibility with callers that set a module global
        simupath = globals().get('simupath')
    if simupath is not None:
        mat_dict = scipy.io.loadmat(simupath)
        panels.append((mat_dict['J'][trial_ind], mat_dict['u'][trial_ind]))

    # built from T so the axis always has exactly T samples
    times_in_ms = (tmin + np.arange(T) * tstep) * 1000.0
    for tmp_J, tmp_u in panels:
        plt.figure()
        for l in range(n_ROI):
            ROI_id = 0
            plt.subplot(n_ROI, 1, l + 1)
            plt.plot(times_in_ms, tmp_J[ROI_list[l], :].T, 'b', alpha=0.1)
            if l < n_ROI_valid:
                plt.plot(times_in_ms, tmp_u[l, :], 'k', lw=2, alpha=1)
                ROI_id = l + 1
            plt.xlabel('time ms')
            plt.title("ROI %d" % ROI_id)
        plt.tight_layout()

    # save the across-trial std of J as an STC
    vertices_to = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']]
    stc = mne.SourceEstimate(data=J_std,
                             vertices=vertices_to,
                             tmin=tmin,
                             tstep=tstep)
    stc.save(out_stc_name)
    return 0
def get_estimate_ks(M, ROI_list, n_ROI_valid, fwd_path, evoked_path, noise_cov_path, out_name,
                    prior_Q0=None, prior_Q=None, prior_sigma_J_list=None,
                    prior_A=None,
                    depth=None,
                    MaxIter0=100, MaxIter=50, MaxIter_coarse=10,
                    tol0=1E-4, tol=1E-2,
                    verbose0=True, verbose=False, verbose_coarse=True,
                    L_flag=False,
                    whiten_flag=True,
                    n_ini=0, n_pool=2, flag_A_time_vary=False, use_pool=False,
                    ini_Gamma0_list=None, ini_A_list=None, ini_Gamma_list=None,
                    ini_sigma_J_list=None, force_fixed=True, flag_inst_ini=True,
                    a_ini=0.1):
    """Fit the state-space ROI model with EM from several starting points.

    Several initializations are each refined with a few (``MaxIter_coarse``)
    EM iterations via ``use_EM``:
      - the caller-supplied ``ini_*_list`` entries, plus a default one;
      - a least-squares estimate of the ROI signals;
      - optionally, the instantaneous model (``flag_inst_ini``);
      - optionally, piecewise-constant A (``n_ini > 0`` and
        ``flag_A_time_vary``).
    The best one (smallest ``obj``) is then run to convergence, and the
    estimates are saved to ``out_name``.

    Inputs:
        M, [q, n_channels, n_times] sensor data
        ROI_list, list of dipole-index arrays, one per ROI
        n_ROI_valid, number of ROIs with a latent time course
        fwd_path, full path of the forward solution
        evoked_path, full path of the evoked template
        noise_cov_path, full path of the noise covariance
        out_name, full path of the mat file to save

        Due to scale issues, no depth weighting should be used in
        simulations, because normalizing G strongly violates the source
        generation assumptions.
        priors:
        prior_Q0, prior_Q, prior_sigma_J_list: not implemented, may be
            inverse gamma or gamma
        prior_A: dict(lambda0 = 0.0, lambda1 = 1.0)

        depth: forward weighting parameter (None means 0)
        whiten_flag: if True, whiten the data so that the sensor error
            covariance is the identity
        n_ini: if >= 0, add the default and least-squares initializations;
            if > 0 (with flag_A_time_vary), also add n_ini piecewise-A
            initializations
        ini_Gamma0_list, ini_A_list, ini_Gamma_list, ini_sigma_J_list:
            extra initial values; must have the same length. They are
            copied and never modified.

    Output:
        Saves a .mat file with Q0_hat, Q_hat, A_hat, Sigma_J_list_hat,
        L_list_hat, u_array_hat and obj. Returns None.

    Raises:
        ValueError if the ini_* lists differ in length, or if there is no
        initialization to optimize.
    """
    if depth is None:
        depth = 0.0

    q, _, T0 = M.shape
    T = T0 - 1  # number of transitions between consecutive time points
    # this function returns a list, take the first element
    evoked = mne.read_evokeds(evoked_path)[0]
    print(force_fixed)
    fwd0 = mne.read_forward_solution(fwd_path, force_fixed=force_fixed, surf_ori=True)
    fwd = copy.deepcopy(fwd0)
    noise_cov = mne.read_cov(noise_cov_path)
    Sigma_E = noise_cov.data

    # positions and orientations of the dipoles in use (lh then rh)
    ind0 = fwd['src'][0]['inuse']
    ind1 = fwd['src'][1]['inuse']
    rr = np.vstack([fwd['src'][0]['rr'][ind0 == 1, :],
                    fwd['src'][1]['rr'][ind1 == 1, :]])
    # NOTE(review): divides by the largest *squared* norm, not the largest
    # norm; kept because L_list_param below is presumably tuned for it.
    rr = rr / np.max(np.sum(rr ** 2, axis=1))
    nn = np.vstack([fwd['src'][0]['nn'][ind0 == 1, :],
                    fwd['src'][1]['nn'][ind1 == 1, :]])
    bads = evoked.info['bads']
    sel = [i for i, name in enumerate(evoked.ch_names) if name not in bads]
    print("fixed orient")
    print(is_fixed_orient(fwd))
    if not is_fixed_orient(fwd):
        _to_fixed_ori(fwd)
    print("difference of G")
    gain_abs = np.abs(fwd['sol']['data'])
    print(np.max(gain_abs) / np.min(gain_abs))

    if whiten_flag:
        G, _, whitener, _, _ = _prepare_gain(fwd, evoked.info, noise_cov,
                                             pca=True, depth=depth,
                                             loose=None, weights=None,
                                             weights_min=None)
        # after whitening, the sensor noise covariance is the identity
        Sigma_E = np.eye(G.shape[0])
        M = np.dot(whitener, M).transpose([1, 0, 2])
    else:
        # NOTE(review): Sigma_E keeps every channel of the noise covariance
        # while G keeps only the non-bad channels; confirm they match.
        G = fwd['sol']['data'][sel, :]
        # 2.0 avoids integer division under Python 2 for an integer depth
        G = G / np.sum(G ** 2, axis=0) ** (depth / 2.0)

    # Prior on the within-ROI weights L: covariance between dipoles a, b of
    # an ROI is <nn_a, nn_b> * exp(-b ||rr_a - rr_b||^2), evaluated at the
    # ROI's own dipoles (ROI_list[i] indexes the source space).
    L_list_param = 1.5
    prior_L_precision = list()
    for i in range(n_ROI_valid):
        roi = ROI_list[i]
        nn_roi = nn[roi]
        rr_roi = rr[roi]
        sq_dist = np.sum((rr_roi[:, np.newaxis, :] - rr_roi[np.newaxis, :, :]) ** 2,
                         axis=-1)
        Q_L = nn_roi.dot(nn_roi.T) * np.exp(-L_list_param * sq_dist)
        prior_L_precision.append(np.linalg.inv(Q_L))

    y_array = M.transpose([0, 2, 1])  # q, T0, n
    scale_factor = 1E-9
    p = n_ROI_valid
    L_list_0 = [np.ones(ROI_list[i].size) for i in range(n_ROI_valid)]

    # default initial values (A diagonal with a_ini)
    Gamma0_0 = np.eye(p) * scale_factor
    Gamma_0 = np.eye(p) * scale_factor
    if flag_A_time_vary:
        A_0 = np.tile(np.eye(p) * a_ini, [T, 1, 1])
    else:
        A_0 = np.eye(p) * a_ini
    sigma_J_list_0 = np.ones(p) * scale_factor

    def em_param(Gamma0_ini, A_ini, Gamma_ini, sigma_J_ini,
                 L_list_ini=L_list_0, y=y_array, time_vary=flag_A_time_vary,
                 final=False, em_verbose=verbose_coarse):
        # Keyword dict for use_EM: a coarse run unless ``final`` is True.
        return dict(y_array=y, G=G, ROI_list=ROI_list, Sigma_E=Sigma_E,
                    Gamma0_0=Gamma0_ini, A_0=A_ini, Gamma_0=Gamma_ini,
                    sigma_J_list_0=sigma_J_ini, L_list_0=L_list_ini,
                    flag_A_time_vary=time_vary,
                    prior_Q0=prior_Q0, prior_A=prior_A, prior_Q=prior_Q,
                    prior_L_precision=prior_L_precision,
                    prior_sigma_J_list=prior_sigma_J_list,
                    MaxIter0=MaxIter0, tol0=tol0,
                    verbose0=verbose0 if final else False,
                    MaxIter=MaxIter if final else MaxIter_coarse,
                    tol=tol, verbose=em_verbose, L_flag=L_flag)

    # Copy so that appending the default initialization never mutates the
    # caller's lists.
    ini_Gamma0_list = [] if ini_Gamma0_list is None else list(ini_Gamma0_list)
    ini_A_list = [] if ini_A_list is None else list(ini_A_list)
    ini_Gamma_list = [] if ini_Gamma_list is None else list(ini_Gamma_list)
    ini_sigma_J_list = [] if ini_sigma_J_list is None else list(ini_sigma_J_list)
    if not (len(ini_A_list) == len(ini_Gamma_list) == len(ini_sigma_J_list)
            == len(ini_Gamma0_list)):
        raise ValueError("ini_Gamma0_list, ini_A_list, ini_Gamma_list and "
                         "ini_sigma_J_list must have the same length")

    # if n_ini >= 0, append the default initialization, else do not
    if n_ini >= 0:
        ini_Gamma0_list.append(Gamma0_0)
        ini_A_list.append(A_0)
        ini_Gamma_list.append(Gamma_0)
        ini_sigma_J_list.append(sigma_J_list_0)

    ini_param_list = [em_param(g0, a0, g, s) for g0, a0, g, s in
                      zip(ini_Gamma0_list, ini_A_list, ini_Gamma_list,
                          ini_sigma_J_list)]

    # second initialization, least squares
    m = G.shape[1]
    L = np.zeros([m, n_ROI_valid])
    for i in range(n_ROI_valid):
        L[ROI_list[i], i] = L_list_0[i]
    C = G.dot(L)
    R0 = Sigma_E.copy()
    for l in range(len(sigma_J_list_0)):
        G_roi = G[:, ROI_list[l]]
        R0 += sigma_J_list_0[l] ** 2 * G_roi.dot(G_roi.T)
    u_array_hat = get_lsq_u(y_array, R0, C)
    # priors all None: avoid coordinate descent to get the global solution
    Gamma0_ls, A_ls, Gamma_ls = get_param_given_u(
        u_array_hat, Gamma0_0, A_0, Gamma_0,
        flag_A_time_vary=flag_A_time_vary,
        prior_Q0=None, prior_A=None, prior_Q=None,
        MaxIter0=MaxIter0, tol0=tol0, verbose0=verbose0,
        MaxIter=MaxIter, tol=tol, verbose=verbose)
    print("Gamma0_ls and Gamma_ls")
    print(Gamma0_ls)
    print(Gamma_ls)
    if n_ini >= 0:
        ini_param_list.append(em_param(Gamma0_ls, A_ls, Gamma_ls, sigma_J_list_0))

    if flag_inst_ini:
        # run the instantaneous model to initialize Q and sigma_J_list
        print("initilization using my instantaneous model")
        t_ind = 1
        M_t = M[:, :, t_ind]
        MMT = M_t.T.dot(M_t)
        Qu0 = np.eye(p) * scale_factor ** 2
        Sigma_J_list0 = np.ones(len(ROI_list)) * scale_factor ** 2
        # prior hyper-parameters; unused because all prior_* flags are False
        alpha, beta = 1.0, 1.0
        nu = p + 1
        V_inv = np.eye(p) * 1E-4
        eps = 1E-13
        inv_Q_L_list = [np.eye(len(ROI_list[i])) for i in range(n_ROI_valid)]
        Qu_hat0, Sigma_J_list_hat0, _, _ = inst.get_map_coor_descent(
            Qu0, Sigma_J_list0, L_list_0,
            ROI_list, G, MMT, q, Sigma_E,
            nu, V_inv, inv_Q_L_list, alpha, beta,
            prior_Q=False, prior_Sigma_J=False, prior_L=False,
            Q_flag=True, Sigma_J_flag=True, L_flag=False,
            tau=0.8, step_ini=1.0, MaxIter=MaxIter, tol=tol,
            eps=eps, verbose=verbose, verbose0=verbose0,
            MaxIter0=MaxIter0, tol0=tol0)
        print("%s %s" % (Qu_hat0, Sigma_J_list_hat0))
        Gamma_inst = np.linalg.cholesky(Qu_hat0)
        # separate arrays for Gamma0 and Gamma so in-place updates can't alias
        ini_param_list.append(em_param(Gamma_inst, A_0, Gamma_inst.copy(),
                                       np.sqrt(Sigma_J_list_hat0)))

    if n_ini > 0 and flag_A_time_vary:
        # For l = 0..n_ini-1, split the T+1 time points into l+1 consecutive
        # segments (adjacent segments share one boundary point), fit a fixed
        # A on each segment, then concatenate them into a time-varying A.
        time_ind_dict_list = list()
        for l in range(n_ini):
            n_time_per_segment = (T + 1) // (l + 1)
            for l0 in range(l):
                time_ind_dict_list.append(dict(l=l, time_ind=list(range(
                    l0 * n_time_per_segment, (l0 + 1) * n_time_per_segment + 1))))
            time_ind_dict_list.append(dict(l=l, time_ind=list(range(
                l * n_time_per_segment, T + 1))))

        tmp_A0 = np.eye(p) * 0.9
        ini_param_fixed_A_list = list()
        for l0, seg in enumerate(time_ind_dict_list):
            print(l0)
            y_array_tmp = y_array[:, seg['time_ind'], :]
            print(y_array_tmp.shape)
            ini_param_fixed_A_list.append(em_param(
                Gamma0_0, tmp_A0, Gamma_0, sigma_J_list_0,
                y=y_array_tmp, time_vary=False))

        # solve the individual segments
        if use_pool:
            pool = Pool(n_pool)
            try:
                result_fixed_list = pool.map(use_EM, ini_param_fixed_A_list)
            finally:
                pool.close()
                pool.join()
        else:
            result_fixed_list = list()
            for l0, param in enumerate(ini_param_fixed_A_list):
                print("fixed %d th ini_param" % l0)
                result_fixed_list.append(use_EM(param))

        # combine into piecewise A and average Q, add them to the param list
        for l in range(n_ini):
            relevant_ind = [l0 for l0, seg in enumerate(time_ind_dict_list)
                            if seg['l'] == l]
            tmp_A0 = np.zeros([T, p, p])
            tmp_Q0 = np.zeros([p, p])
            for l0 in relevant_ind:
                for t0 in time_ind_dict_list[l0]['time_ind'][1:]:
                    tmp_A0[t0 - 1, :, :] = result_fixed_list[l0]['A']
                tmp_Gamma = result_fixed_list[l0]['Gamma']
                # Q = Gamma Gamma', consistent with Q_hat below
                tmp_Q0 += tmp_Gamma.dot(tmp_Gamma.T)
            tmp_Q0 /= float(len(relevant_ind))
            ini_param_list.append(em_param(
                Gamma0_0, tmp_A0, np.linalg.cholesky(tmp_Q0), sigma_J_list_0,
                em_verbose=False))

    # refine every starting point with a few iterations
    print("optimizing %d initializations" % len(ini_param_list))
    if not ini_param_list:
        raise ValueError("no initialization to optimize: use n_ini >= 0, "
                         "ini_* lists or flag_inst_ini")
    if use_pool:
        print("using pool")
        pool = Pool(n_pool)
        try:
            result_list = pool.map(use_EM, ini_param_list)
        finally:
            pool.close()
            pool.join()
    else:
        result_list = [use_EM(param) for param in ini_param_list]

    obj_all = np.zeros(len(result_list))
    for l, res in enumerate(result_list):
        obj_all[l] = res['obj']
    print(obj_all)
    best = result_list[np.argmin(obj_all)]

    # run the best starting point to convergence
    result0 = use_EM(em_param(best['Gamma0'], best['A'], best['Gamma'],
                              best['sigma_J_list'], L_list_ini=best['L_list'],
                              final=True, em_verbose=verbose))
    Gamma0_hat, A_hat, Gamma_hat = result0['Gamma0'], result0['A'], result0['Gamma']
    print(result0['obj'])
    result = dict(Q0_hat=Gamma0_hat.dot(Gamma0_hat.T),
                  Q_hat=Gamma_hat.dot(Gamma_hat.T),
                  A_hat=A_hat,
                  Sigma_J_list_hat=result0['sigma_J_list'] ** 2,
                  L_list_hat=result0['L_list'],
                  u_array_hat=result0['u_t_T_array'], obj=result0['obj'])
    scipy.io.savemat(out_name, result)
        
Beispiel #14
0
def test_io_forward():
    """Test IO for forward solutions (read, convert, write, read back)."""
    temp_dir = _TempDir()
    # do extensive tests with MEEG + grad
    n_channels, n_src = 366, 108
    # leadfield columns of fname_meeg before fixing the orientation
    # (3 per source)
    n_src_meeg = 1494
    fwd = read_forward_solution(fname_meeg_grad)
    assert_true(isinstance(fwd, Forward))
    fwd = convert_forward_solution(fwd, surf_ori=True)
    leadfield = fwd['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd['sol']['row_names']), n_channels)
    fname_temp = op.join(temp_dir, 'test-fwd.fif')
    write_forward_solution(fname_temp, fwd, overwrite=True)

    # re-read the original so the comparison does not depend on the write
    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True)
    leadfield = fwd_read['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd_read['sol']['row_names']), n_channels)
    assert_equal(len(fwd_read['info']['chs']), n_channels)
    assert_true('dev_head_t' in fwd_read['info'])
    assert_true('mri_head_t' in fwd_read)
    assert_array_almost_equal(fwd['sol']['data'], fwd_read['sol']['data'])

    # fixed orientation without cortical patch statistics
    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=False)
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True,
                                        force_fixed=True, use_cps=False)
    assert_true(repr(fwd_read))
    assert_true(isinstance(fwd_read, Forward))
    assert_true(is_fixed_orient(fwd_read))
    compare_forwards(fwd, fwd_read)

    # fixed orientation with cortical patch statistics
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    leadfield = fwd['sol']['data']
    # integer division: shapes are ints
    assert_equal(leadfield.shape, (n_channels, n_src_meeg // 3))
    assert_equal(len(fwd['sol']['row_names']), n_channels)
    assert_equal(len(fwd['info']['chs']), n_channels)
    assert_true('dev_head_t' in fwd['info'])
    assert_true('mri_head_t' in fwd)
    assert_true(fwd['surf_ori'])
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True,
                                        force_fixed=True, use_cps=True)
    assert_true(repr(fwd_read))
    assert_true(isinstance(fwd_read, Forward))
    assert_true(is_fixed_orient(fwd_read))
    compare_forwards(fwd, fwd_read)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    leadfield = fwd['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src // 3))
    assert_equal(len(fwd['sol']['row_names']), n_channels)
    assert_equal(len(fwd['info']['chs']), n_channels)
    assert_true('dev_head_t' in fwd['info'])
    assert_true('mri_head_t' in fwd)
    assert_true(fwd['surf_ori'])
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True,
                                        force_fixed=True, use_cps=True)
    assert_true(repr(fwd_read))
    assert_true(isinstance(fwd_read, Forward))
    assert_true(is_fixed_orient(fwd_read))
    compare_forwards(fwd, fwd_read)

    # test warnings on bad filenames
    fwd = read_forward_solution(fname_meeg_grad)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        fwd_badname = op.join(temp_dir, 'test-bad-name.fif.gz')
        write_forward_solution(fwd_badname, fwd)
        read_forward_solution(fwd_badname)
    assert_naming(w, 'test_forward.py', 2)

    # plain round trip without any conversion
    fwd = read_forward_solution(fname_meeg)
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    compare_forwards(fwd, fwd_read)
def _build_dipole_groups(label_ind, n_dipoles, n_dip_per_pos,
                         flag_nonroi_l2):
    """Build the first-level penalty groups over the gain-matrix columns.

    Parameters
    ----------
    label_ind : list of array-like of int
        For each ROI, the indices of the source positions it contains.
    n_dipoles : int
        Number of source positions.
    n_dip_per_pos : int
        Number of gain columns per source position (1 for fixed
        orientation, 3 otherwise). Position ``i`` owns the columns
        ``n_dip_per_pos * i + k`` for ``k in range(n_dip_per_pos)``.
    flag_nonroi_l2 : bool
        If True, all positions outside the ROIs form a single group
        (only added when there is at least one such position); otherwise
        each of them is its own group.

    Returns
    -------
    groups : list of ndarray of int
        Gain-column indices of each group. The ROI groups come first, in
        the order of ``label_ind``.
    is_in_roi : ndarray of bool, shape (n_dipoles,)
        Whether each source position belongs to at least one ROI.
    """
    is_in_roi = np.zeros(n_dipoles, dtype=bool)
    offsets = np.arange(n_dip_per_pos)

    def _columns(positions):
        # Orientation-major order, e.g. [3i..., 3i+1..., 3i+2...] for free
        # orientation; reduces to the positions themselves when
        # n_dip_per_pos == 1.
        positions = np.asarray(positions, dtype=int)
        return (n_dip_per_pos * positions[None, :]
                + offsets[:, None]).ravel()

    groups = []
    for positions in label_ind:
        groups.append(_columns(positions))
        # is_in_roi is indexed by source position, not by gain column
        is_in_roi[np.asarray(positions, dtype=int)] = True
    outside = np.nonzero(~is_in_roi)[0]
    if flag_nonroi_l2:
        # an empty group would get an infinite weight (1 / 0)
        if outside.size > 0:
            groups.append(_columns(outside))
    else:
        groups.extend(_columns([pos]) for pos in outside)
    return groups, is_in_roi


def get_STFT_R_solution(evoked_list, X, fwd_list0, G_ind, noise_cov,
                        label_list, GroupWeight_Param,
                        active_set_z0,
                        alpha_seq, beta_seq, gamma_seq,
                        loose=None, depth=0.0, maxit=500, tol=1e-4,
                        wsize=16, tstep=4, window=0.02,
                        L2_option=0, delta_seq=None,
                        coef_non_zero_mat=None, Z0_l2=None,
                        Maxit_J=10, Incre_Group_Numb=50, dual_tol=0.01,
                        Flag_backtrack=True, L0=1.0, eta=1.5,
                        Flag_verbose=False,
                        Flag_nonROI_L2=False):
    '''Compute the L21 and/or L2 solution of the STFT regression.

    The whitened sensor data of each trial are regressed on the design
    matrix X. Trial r uses the gain matrix of run G_ind[r]. The
    coefficients live in the STFT domain: each active dipole has p
    regressors, each with n_freq x n_step coefficients.

    Input:
        evoked_list, a list of n_trials evoked objects. It is not modified.
        X, [n_trials, p] design matrix of the regression.
        fwd_list0, a list of forward solution objects, one per run. They
            are deep-copied and not modified.
        G_ind, [n_trials, ] run index of each trial, starting from zero.
        noise_cov, the noise covariance matrix.
        label_list, a list of labels (ROIs), or None. If None, every source
            position is its own group, GroupWeight_Param has no effect, and
            Flag_nonROI_L2 is forced to False.
        GroupWeight_Param, ratio of the weights of the ROI groups to those
            of the other groups. Each group weight is 1 / (group size); the
            ROI weights are multiplied by this ratio, and all weights are
            then normalized to sum to one.
        active_set_z0, boolean array over the gain columns, the initial
            active set. It is copied, not modified. With L2_option == 2 it
            is the active set of the L2 problem.
        alpha_seq, tuning sequence for alpha (the group penalty).
        beta_seq, tuning sequence for beta (penalty on single STFT
            coefficients).
        gamma_seq, tuning sequence for gamma. If any of alpha_seq,
            beta_seq or gamma_seq has more than one value, the values are
            selected by two-fold cross validation (even/odd trials).
        loose, depth, passed to _prepare_gain. If loose is None, forward
            solutions that are not fixed orientation are converted to fixed
            orientation. depth (0 to 1) is the MNE depth prior; it needs
            forward solutions that are not in fixed orientation.
        maxit, maximum number of iterations.
        tol, numerical tolerance of the optimization.
        wsize, window size of the STFT.
        tstep, time step of the STFT.
        window, passed to _window_evoked to taper the data edges; None
            disables it.
        L2_option,
            0, only compute the L21 solution.
            1, compute the L21 solution, then use its non-zero coefficients
               as the support of an L2 solution.
            2, only compute the L2 solution; coef_non_zero_mat must be
               given, and active_set_z0 must be the matching active set.
        delta_seq, tuning sequence for the L2 solution. If None,
            [1e-12, 1e-10, 1e-8] is used. With more than one value, delta is
            selected by cross validation.
        coef_non_zero_mat, [active_set.sum(), n_coefs*p] boolean matrix,
            the coefficient support for the L2 problem,
            e.g. np.abs(Z) > 0.
        Z0_l2, same shape as coef_non_zero_mat, initial value of the L2
            problem (L2_option == 2 only).
        Maxit_J, maximum number of greedy steps of the active-set method
            for the L21 problem.
        Incre_Group_Numb, number of first-level groups added in each greedy
            step of the L21 problem.
        dual_tol, the greedy L21 method stops when the KKT violation is
            below this value.
        Flag_backtrack, L0, eta, backtracking parameters.
        Flag_verbose, whether the solvers print optimization details.
        Flag_nonROI_L2, if True, all dipoles outside the ROIs form a single
            group.

    Output:
        Z_full, [active_set.sum(), p*n_freq*n_step] complex matrix, the
            regression coefficients, with the source weighting reapplied.
        active_set, boolean array, dipole active set.
        active_t_ind, [n_step, ] boolean array, temporal active set
            (all True).
        alpha_star, beta_star, gamma_star, the selected tuning parameters
            (None if L2_option == 2).
        delta_star, the selected L2 tuning parameter (None if
            L2_option == 0).
        cv_MSE_lasso, cross-validation error of the L21 selection (0 if no
            selection was run).
        cv_MSE_L2, cross-validation error of the delta selection (0 if no
            selection was run).

    Raises:
        ValueError, if L2_option is not 0, 1 or 2, or if L2_option == 2
            and coef_non_zero_mat is None.
        Exception, if the L21 solver finds no active dipoles.
    '''
    if L2_option not in (0, 1, 2):
        raise ValueError("L2_option must be 0, 1 or 2, got %r"
                         % (L2_option,))
    # Work on copies. The trials are replaced by their windowed versions,
    # and the ROI dipoles are forced into the initial active set below;
    # neither change should leak back to the caller.
    evoked_list = list(evoked_list)
    active_set_z0 = np.array(active_set_z0, copy=True)

    # =========================================================================
    # Prepare the forward solutions: one gain matrix per run
    weights, weights_min, pca = None, None, True
    all_ch_names = evoked_list[0].ch_names
    info = evoked_list[0].info
    n_trials = len(evoked_list)
    n_runs = len(np.unique(G_ind))
    G_list = list()
    whitener_list = list()
    fwd_list = deepcopy(fwd_list0)
    for run_id in range(n_runs):
        # follow tf_mixed_norm: without a loose constraint, use fixed
        # orientation
        if loose is None and not is_fixed_orient(fwd_list[run_id]):
            _to_fixed_ori(fwd_list[run_id])
        # mask is expected to be None. gain_info, whitener,
        # source_weighting and mask of the LAST run are used below.
        gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
            fwd_list[run_id], info, noise_cov, pca, depth, loose, weights,
            weights_min)
        G_list.append(gain)
        whitener_list.append(whitener)

    # =========================================================================
    # Window, project and whiten the sensor data
    if window is not None:
        for r in range(n_trials):
            evoked_list[r] = _window_evoked(evoked_list[r], window)
    sel = [all_ch_names.index(name) for name in gain_info["ch_names"]]
    _, n_times = evoked_list[0].data[sel].shape
    n_sensors = G_list[0].shape[0]
    M = np.zeros([n_sensors, n_times, n_trials], dtype=float)
    logger.info('Accessing and Whitening data matrix.')
    # The SSP projector is built from the first trial's info. All forward
    # solutions must have the same channels; bad channels must be removed
    # from every trial before calling this function.
    info = evoked_list[0].info
    fwd_ch_names = [c['ch_name'] for c in fwd_list[0]['info']['chs']]
    ch_names = [c['ch_name'] for c in info['chs']
                if (c['ch_name'] not in info['bads']
                    and c['ch_name'] not in noise_cov['bads'])
                and (c['ch_name'] in fwd_ch_names
                     and c['ch_name'] in noise_cov.ch_names)]
    proj, _, _ = mne.io.proj.make_projector(info['projs'], ch_names)
    # The whitener is assumed identical across runs, so the last one is
    # used for every trial. (whitener . proj) . data is the same product
    # order as before; it is only hoisted out of the loop.
    whiten_proj = np.dot(whitener, proj)
    for r in range(n_trials):
        M[:, :, r] = np.dot(whiten_proj, evoked_list[r].data[sel])

    # =========================================================================
    # Group information
    src = fwd_list[0]['src']
    n_dip_per_pos = 1 if is_fixed_orient(fwd_list[0]) else 3
    # number of source positions; each owns n_dip_per_pos gain columns
    n_dipoles = G_list[0].shape[1] // n_dip_per_pos
    if label_list is None:
        label_ind = []
        Flag_nonROI_L2 = False
    else:
        # source-position indices of each ROI
        label_ind = [label_src_vertno_sel(label, src)[1]
                     for label in label_list]
    nROI = len(label_ind)
    DipoleGroup, isinROI = _build_dipole_groups(
        label_ind, n_dipoles, n_dip_per_pos, Flag_nonROI_L2)
    # group weights: 1 / group size, ROI groups scaled by GroupWeight_Param,
    # then normalized to sum to one
    DipoleGroupWeight = 1.0 / np.array([len(x) for x in DipoleGroup])
    DipoleGroupWeight[0:nROI] *= GroupWeight_Param
    DipoleGroupWeight /= DipoleGroupWeight.sum()

    # =========================================================================
    # STFT constants
    n_step = int(np.ceil(n_times / float(tstep)))
    n_freq = wsize // 2 + 1
    n_coefs = n_step * n_freq
    p = X.shape[1]
    # =========================================================================
    # Scale the gains so that alpha is relative to 1% of the l2,inf norm of
    # G_0^T M_0 (modified from tf_mixed_norm in mne v0.11). In v0.11,
    # source_weighting is divided as well ("gain /= alpha_max;
    # source_weighting /= alpha_max").
    alpha_max = norm_l2inf(np.dot(G_list[0].T, M[:, :, 0]),
                           n_dip_per_pos, copy=False)
    alpha_max *= 0.01
    for run_id in range(n_runs):
        G_list[run_id] /= alpha_max
    source_weighting /= alpha_max
    # two-fold cross-validation partition: even vs. odd trials
    cv_partition_ind = np.zeros(n_trials)
    cv_partition_ind[1::2] = 1
    cv_MSE_lasso, cv_MSE_L2 = 0, 0

    # =========================================================================
    if L2_option in (0, 1):
        # L21 solution: the ROI dipoles must be in the initial active set
        isinROI_ind = np.nonzero(isinROI)[0]
        for k in range(n_dip_per_pos):
            active_set_z0[n_dip_per_pos * isinROI_ind + k] = True
        active_set_J_ini = np.zeros(len(DipoleGroup), dtype=bool)
        for j, group in enumerate(DipoleGroup):
            if np.sum(active_set_z0[group]) > 0:
                active_set_J_ini[j] = True
        # tuning sequences with more than one value: cross validation
        if len(alpha_seq) > 1 or len(beta_seq) > 1 or len(gamma_seq) > 1:
            print("select alpha,beta and gamma")
            (alpha_star, beta_star, gamma_star, cv_MSE_lasso) = \
                L21solver.select_alpha_beta_gamma_stft_tree_group_cv_active_set(
                    M, G_list, G_ind, X,
                    active_set_J_ini,
                    DipoleGroup, DipoleGroupWeight,
                    alpha_seq, beta_seq, gamma_seq, cv_partition_ind,
                    n_orient=n_dip_per_pos,
                    wsize=wsize, tstep=tstep,
                    maxit=maxit, tol=tol,
                    Maxit_J=Maxit_J, Incre_Group_Numb=Incre_Group_Numb,
                    dual_tol=dual_tol,
                    Flag_backtrack=Flag_backtrack, L0=L0, eta=eta,
                    Flag_verbose=Flag_verbose)
        else:
            alpha_star, beta_star, gamma_star = (alpha_seq[0], beta_seq[0],
                                                 gamma_seq[0])
        # tiny random real initial value; the imaginary part is zero
        n_active0 = active_set_z0.sum()
        Z0 = (np.zeros([n_active0, n_coefs * p]) * 1j
              + np.random.randn(n_active0, n_coefs * p) * 1E-20)
        tmp_result = L21solver.solve_stft_regression_tree_group_active_set(
            M, G_list, G_ind, X,
            alpha_star, beta_star, gamma_star,
            DipoleGroup, DipoleGroupWeight,
            Z0, active_set_z0,
            active_set_J_ini, n_orient=n_dip_per_pos,
            wsize=wsize, tstep=tstep, maxit=maxit, tol=tol,
            Maxit_J=Maxit_J, Incre_Group_Numb=Incre_Group_Numb,
            dual_tol=dual_tol,
            Flag_backtrack=Flag_backtrack, L0=L0, eta=eta,
            Flag_verbose=Flag_verbose)
        if tmp_result is None:
            raise Exception("No active dipoles found. alpha is too big.")
        Z = tmp_result['Z']
        active_set = tmp_result['active_set']
        active_t_ind = np.ones(n_step, dtype=bool)
        # copied from tf_mixed_norm in v0.11
        if mask is not None:
            active_set_tmp = np.zeros(len(mask), dtype=bool)
            active_set_tmp[mask] = active_set
            active_set = active_set_tmp
            del active_set_tmp

    # =========================================================================
    # L2 solution on a given active set (delta_star stays None when
    # L2_option == 0)
    delta_star = None
    if L2_option in (1, 2):
        if L2_option == 2:
            # only the L2 solution is needed: initialize from the inputs
            if coef_non_zero_mat is None:
                raise ValueError("if L2_option == 2, coef_non_zero_mat must not be empty!")
            active_set = active_set_z0.copy()
            active_t_ind = np.ones(n_step, dtype=bool)
            if Z0_l2 is None:
                # tiny random real initial value; the imaginary part is zero
                n_active0 = active_set_z0.sum()
                Z = (np.zeros([n_active0, n_coefs * p]) * 1j
                     + np.random.randn(n_active0, n_coefs * p) * 1E-20)
            else:
                Z = Z0_l2
            alpha_star, beta_star, gamma_star = None, None, None
        else:
            # L2_option == 1: refit on the support of the L21 solution
            coef_non_zero_mat = np.abs(Z) > 0
        if delta_seq is None:
            delta_seq = np.array([1E-12, 1E-10, 1E-8])
        if len(delta_seq) > 1:
            # boolean fancy indexing returns a copy of Z
            Z0 = Z[:, np.tile(active_t_ind, p * n_freq)]
            delta_star, cv_MSE_L2 = L2solver.select_delta_stft_regression_cv(
                M, G_list, G_ind, X,
                Z0, active_set, active_t_ind,
                coef_non_zero_mat,
                delta_seq, cv_partition_ind,
                wsize=wsize, tstep=tstep,
                maxit=maxit, tol=tol,
                Flag_backtrack=Flag_backtrack, L0=L0, eta=eta,
                Flag_verbose=Flag_verbose)
        else:
            delta_star = delta_seq[0]
        Z, _ = L2solver.solve_stft_regression_L2_tsparse(
            M, G_list, G_ind, X, Z, active_set,
            active_t_ind, coef_non_zero_mat,
            wsize=wsize, tstep=tstep, delta=delta_star,
            maxit=maxit, tol=tol,
            Flag_backtrack=Flag_backtrack, L0=L0, eta=eta,
            Flag_verbose=Flag_verbose)

    # =========================================================================
    # Reapply the source weighting to get the correct units. This must happen
    # after the debiasing (L2) step. In MNE 0.8 this was
    # "Z /= source_weighting[active_set][:, None]"; MNE 0.11 uses
    # _reapply_source_weighting.
    Z = _reapply_source_weighting(Z, source_weighting, active_set,
                                  n_dip_per_pos)
    n_active = active_set.sum()
    Z_full = np.zeros([n_active, p, n_freq, n_step], dtype=complex)
    Z_full[:, :, :, active_t_ind] = np.reshape(
        Z, [n_active, p, n_freq, active_t_ind.sum()])
    Z_full = np.reshape(Z_full, [n_active, -1])
    # source estimates (stc objects) are intentionally not computed here
    return (Z_full, active_set, active_t_ind, alpha_star, beta_star,
            gamma_star, delta_star, cv_MSE_lasso, cv_MSE_L2)
# Example #16 (scraped header: "Beispiel #16", "0")
def test_io_forward():
    """Test IO for forward solutions.

    Round-trips forward solutions through ``write_forward_solution`` and
    ``read_forward_solution`` in surface orientation and in fixed
    orientation (``use_cps=False`` and ``use_cps=True``). Also checks the
    warnings raised for file names that do not follow the naming
    convention.
    """
    temp_dir = _TempDir()
    fname_temp = op.join(temp_dir, 'test-fwd.fif')
    # do extensive tests with MEEG + grad
    n_channels, n_src = 366, 108
    # in fixed orientation, fname_meeg has n_src_meeg // 3 gain columns
    n_src_meeg = 1494

    def _check_fixed_layout(fwd, n_cols):
        """Check the layout of a surface-oriented, fixed forward solution."""
        assert_equal(fwd['sol']['data'].shape, (n_channels, n_cols))
        assert_equal(len(fwd['sol']['row_names']), n_channels)
        assert_equal(len(fwd['info']['chs']), n_channels)
        assert 'dev_head_t' in fwd['info']
        assert 'mri_head_t' in fwd
        assert fwd['surf_ori']

    def _check_fixed_round_trip(fwd, use_cps):
        """Write *fwd*, read it back as fixed orientation and compare."""
        with pytest.warns(RuntimeWarning, match='stored on disk'):
            write_forward_solution(fname_temp, fwd, overwrite=True)
        fwd_read = read_forward_solution(fname_temp)
        fwd_read = convert_forward_solution(fwd_read, surf_ori=True,
                                            force_fixed=True,
                                            use_cps=use_cps)
        assert repr(fwd_read)
        assert isinstance(fwd_read, Forward)
        assert is_fixed_orient(fwd_read)
        compare_forwards(fwd, fwd_read)

    # surface orientation
    fwd = read_forward_solution(fname_meeg_grad)
    assert isinstance(fwd, Forward)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    leadfield = fwd['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd['sol']['row_names']), n_channels)
    with pytest.warns(RuntimeWarning, match='stored on disk'):
        write_forward_solution(fname_temp, fwd, overwrite=True)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True)
    leadfield = fwd_read['sol']['data']
    assert_equal(leadfield.shape, (n_channels, n_src))
    assert_equal(len(fwd_read['sol']['row_names']), n_channels)
    assert_equal(len(fwd_read['info']['chs']), n_channels)
    assert 'dev_head_t' in fwd_read['info']
    assert 'mri_head_t' in fwd_read
    assert_array_almost_equal(fwd['sol']['data'], fwd_read['sol']['data'])

    # fixed orientation, use_cps=False
    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=False)
    _check_fixed_round_trip(fwd, use_cps=False)

    # fixed orientation, use_cps=True; converts the solution obtained just
    # above again
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    _check_fixed_layout(fwd, n_src_meeg // 3)
    _check_fixed_round_trip(fwd, use_cps=True)

    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                   use_cps=True)
    _check_fixed_layout(fwd, n_src // 3)
    _check_fixed_round_trip(fwd, use_cps=True)

    # test warnings on bad filenames
    fwd = read_forward_solution(fname_meeg_grad)
    fwd_badname = op.join(temp_dir, 'test-bad-name.fif.gz')
    with pytest.warns(RuntimeWarning, match='end with'):
        write_forward_solution(fwd_badname, fwd)
    with pytest.warns(RuntimeWarning, match='end with'):
        read_forward_solution(fwd_badname)

    fwd = read_forward_solution(fname_meeg)
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    compare_forwards(fwd, fwd_read)