예제 #1
0
def generate_figure(R, MU, n_src, true_csd_xlims, total_ele,
                    method='cross-validation', Rs=None, lambdas=None,
                    noise=0, save_path=''):
    """
    Generates figure for targeted basis investigation.

    The figure is a 3x3 grid of reconstructions. Each row re-simulates the
    data with a different electrode placement (``[0, 1]``, ``[0, 0.5]``,
    ``[0.5, 1]``). Each column passes a different ``(xmin, xmax, ext_x)``
    basis configuration to ``tb.modified_bases``: ``(0, 1, 0)``,
    ``(-0.5, 1, -0.5)`` and ``(0, 1.5, -0.5)``.

    Parameters
    ----------
    R: float
        Thickness of the groundtruth source.
        Example value: 0.2.
    MU: float
        Central position of Gaussian source
        Example value: 0.25.
    n_src: int
        Number of basis sources.
    true_csd_xlims: list
        Boundaries for ground truth space.
    total_ele: int
        Number of electrodes.
    method: string
        Determines the method of regularization.
        Default: cross-validation.
    Rs: numpy 1D array
        Basis source parameter for crossvalidation.
        Default: None.
    lambdas: numpy 1D array
        Regularization parameter for crossvalidation.
        Default: None.
    noise: float
        Determines the level of noise in the data.
        Default: 0.
    save_path: string
        Directory in which the PNG file is written.
        Default: '' (current working directory).

    Returns
    -------
    None
    """
    # One entry per figure row: electrode limits passed to simulate_data.
    ele_lims_rows = [[0, 1.], [0, 0.5], [0.5, 1.]]
    # One entry per figure column: (xmin, xmax, ext_x) for modified_bases.
    basis_cols = [(0, 1, 0), (-0.5, 1, -0.5), (0, 1.5, -0.5)]
    # Column titles, shown on the top row only.
    # NOTE(review): the B/C titles do not literally equal their column's
    # (xmin, xmax); confirm their meaning against tb.modified_bases.
    col_titles = ['Basis limits = [0, 1]', 'Basis limits = [0, 0.5]',
                  'Basis limits = [0.5, 1]']
    letters = 'ABCDEFGHI'
    n_rows, n_cols = len(ele_lims_rows), len(basis_cols)

    fig = plt.figure(figsize=(15, 12))
    gs = gridspec.GridSpec(n_rows, n_cols, height_ratios=[1] * n_rows,
                           width_ratios=[1] * n_cols, hspace=0.45, wspace=0.3)

    legend_ax = None
    for row, ele_lims in enumerate(ele_lims_rows):
        # Data are simulated once per electrode layout and shared by the row.
        _, true_csd, ele_pos, pots, val = tb.simulate_data(
            tb.csd_profile, true_csd_xlims, R, MU, total_ele, ele_lims,
            noise=noise)
        for col, (xmin, xmax, ext_x) in enumerate(basis_cols):
            ax = fig.add_subplot(gs[row, col])
            obj = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25,
                                    sigma=0.3, gdx=0.01, ext_x=ext_x,
                                    xmin=xmin, xmax=xmax, method=method,
                                    Rs=Rs, lambdas=lambdas)
            # x labels only on the bottom row, y labels only on the left column.
            legend_ax = make_subplot(
                ax, true_csd, obj.values('CSD'), obj.estm_x,
                ele_pos=ele_pos,
                title=col_titles[col] if row == 0 else None,
                xlabel=row == n_rows - 1, ylabel=col == 0,
                letter=letters[row * n_cols + col])

    # The axes returned for the last panel supply the shared figure legend.
    handles, labels = legend_ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=3, frameon=False)

    fig.savefig(os.path.join(save_path, 'targeted_basis_' + method +
                             '_noise_' + str(noise) + '.png'), dpi=300)
    plt.show()
예제 #2
0
def generate_figure_CVLC(R, MU, n_src, true_csd_xlims, total_ele, save_path,
                         Rs=None, lambdas=None, noise=0):
    """
    Generates figure for targeted basis investigation including results from
    both cross-validation and L-curve.

    The figure is a 3x3 grid. Each row re-simulates the data with a different
    electrode placement (``[0, 1]``, ``[0, 0.5]``, ``[0.5, 1]``). Each column
    passes a different ``(xmin, xmax, ext_x)`` basis configuration to
    ``tb.modified_bases``: ``(0, 1, 0)``, ``(-0.5, 1, -0.5)`` and
    ``(0, 1.5, -0.5)``. Every panel shows both the cross-validation estimate
    and the L-curve estimate.

    Parameters
    ----------
    R: float
        Thickness of the groundtruth source.
        Example value: 0.2.
    MU: float
        Central position of Gaussian source
        Example value: 0.25.
    n_src: int
        Number of basis sources.
    true_csd_xlims: list
        Boundaries for ground truth space.
    total_ele: int
        Number of electrodes.
    save_path: string
        Directory in which the PNG file is written.
    Rs: numpy 1D array
        Basis source parameter for crossvalidation.
        Default: None.
    lambdas: numpy 1D array
        Regularization parameter for crossvalidation.
        Default: None.
    noise: float
        Determines the level of noise in the data.
        Default: 0.

    Returns
    -------
    None
    """
    m_cv = 'cross-validation'
    m_lc = 'L-curve'
    method = 'CV_LC'  # tag used only in the output file name
    # One entry per figure row: electrode limits passed to simulate_data.
    ele_lims_rows = [[0, 1.], [0, 0.5], [0.5, 1.]]
    # One entry per figure column: (xmin, xmax, ext_x) for modified_bases.
    basis_cols = [(0, 1, 0), (-0.5, 1, -0.5), (0, 1.5, -0.5)]
    # Column titles, shown on the top row only.
    # NOTE(review): the B/C titles do not literally equal their column's
    # (xmin, xmax); confirm their meaning against tb.modified_bases.
    col_titles = ['Basis limits = [0, 1]', 'Basis limits = [0, 0.5]',
                  'Basis limits = [0.5, 1]']
    letters = 'ABCDEFGHI'
    n_rows, n_cols = len(ele_lims_rows), len(basis_cols)

    fig = plt.figure(figsize=(15, 12))
    gs = gridspec.GridSpec(n_rows, n_cols, height_ratios=[1] * n_rows,
                           width_ratios=[1] * n_cols, hspace=0.45, wspace=0.3)

    legend_ax = None
    for row, ele_lims in enumerate(ele_lims_rows):
        # Data are simulated once per electrode layout and shared by the row.
        _, true_csd, ele_pos, pots, val = tb.simulate_data(
            tb.csd_profile, true_csd_xlims, R, MU, total_ele, ele_lims,
            noise=noise)
        for col, (xmin, xmax, ext_x) in enumerate(basis_cols):
            ax = fig.add_subplot(gs[row, col])
            # Identical basis settings for both regularization methods.
            basis_kwargs = dict(h=0.25, sigma=0.3, gdx=0.01, ext_x=ext_x,
                                xmin=xmin, xmax=xmax, Rs=Rs, lambdas=lambdas)
            obj_CV = tb.modified_bases(val, pots, ele_pos, n_src,
                                       method=m_cv, **basis_kwargs)
            obj_LC = tb.modified_bases(val, pots, ele_pos, n_src,
                                       method=m_lc, **basis_kwargs)
            # x labels only on the bottom row, y labels only on the left column.
            legend_ax = make_subplot(
                ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x,
                ele_pos=ele_pos,
                title=col_titles[col] if row == 0 else None,
                xlabel=row == n_rows - 1, ylabel=col == 0,
                letter=letters[row * n_cols + col],
                est_csd_LC=obj_LC.values('CSD'))

    # The axes returned for the last panel supply the shared figure legend.
    handles, labels = legend_ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=4, frameon=False)

    fig.savefig(os.path.join(save_path, 'targeted_basis_' + method +
                             '_noise_' + str(noise) + '.png'), dpi=300)
    plt.show()
예제 #3
0
def generate_figure(csd_profile,
                    R,
                    MU,
                    true_csd_xlims,
                    total_ele,
                    ele_lims,
                    save_path,
                    noise=0):
    """
    Generates figure for spectral structure decomposition.

    Panels:

    * A: the eigenvalues ``eigenval_M[k]`` against component number, one
      curve per basis width in ``R_init``.
    * B: the first eigenvalue ``eigenval_M[:, 0]`` against R.
    * C onwards: for each eigenvector index ``i``, the normalized vector
      ``k_interp_cross . v_i``, one curve per basis width.

    Parameters
    ----------
    csd_profile: function
        Function to produce csd profile.
    R: float
        Thickness of the groundtruth source.
        Example value: 0.2.
    MU: float
        Central position of Gaussian source
        Example value: 0.25.
    true_csd_xlims: list
        Boundaries for ground truth space.
    total_ele: int
        Number of electrodes.
    ele_lims: list
        Electrodes limits.
    save_path: string
        Directory in which the PNG file is written.
    noise: float
        Determines the level of noise in the data.
        Default: 0.

    Returns
    -------
    None
    """
    _, _, ele_pos, pots, _ = tb.simulate_data(csd_profile,
                                              true_csd_xlims,
                                              R,
                                              MU,
                                              total_ele,
                                              ele_lims,
                                              noise=noise)

    n_src_M = [512]
    R_init = [0.05, 0.1, 0.2, 0.4, 0.8]
    OBJ_M, eigenval_M, eigenvec_M = stability_M(csd_profile,
                                                n_src_M,
                                                ele_lims,
                                                true_csd_xlims,
                                                total_ele,
                                                ele_pos,
                                                pots,
                                                R_init=R_init)

    # Grid (row, first column) of each eigenvector panel; each spans 2 columns.
    plt_cord = [(2, 0), (2, 2), (2, 4), (3, 0), (3, 2), (3, 4), (4, 0), (4, 2),
                (4, 4), (5, 0), (5, 2), (5, 4)]

    letters = ['C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N']

    BLACK = _html(0, 0, 0)
    ORANGE = _html(230, 159, 0)
    SKY_BLUE = _html(86, 180, 233)
    GREEN = _html(0, 158, 115)
    YELLOW = _html(240, 228, 66)
    BLUE = _html(0, 114, 178)
    VERMILION = _html(213, 94, 0)
    PURPLE = _html(204, 121, 167)
    colors = [BLUE, ORANGE, GREEN, PURPLE, VERMILION, SKY_BLUE, YELLOW, BLACK]

    fig = plt.figure(figsize=(18, 16))
    heights = [4, 0.3, 1, 1, 1, 1]
    markers = ['^', '.', '*', 'x', ',']
    linestyles = ['-', '-', '-', '-', '-']

    gs = gridspec.GridSpec(6, 6, height_ratios=heights, hspace=0.3, wspace=0.6)

    # Panel A: eigenvalue spectrum for every basis width.
    ax = fig.add_subplot(gs[0, :3])
    for indx, r_val in enumerate(R_init):
        ax.plot(np.arange(1, total_ele + 1),
                eigenval_M[indx],
                linestyle=linestyles[indx],
                color=colors[indx],
                marker=markers[indx],
                label='R=' + str(r_val),
                markersize=10)
    # These handles/labels feed the shared figure legend at the bottom.
    ht, lh = ax.get_legend_handles_labels()
    set_axis(ax, -0.05, 1.05, letter='A')
    ax.set_xlabel('Number of components')
    ax.set_ylabel('Eigenvalues')
    ax.set_yscale('log')
    ax.set_ylim([1e-6, 1])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Panel B: first eigenvalue as a function of R.
    ax = fig.add_subplot(gs[0, 3:])
    ax.plot(R_init,
            eigenval_M[:, 0],
            marker='s',
            color='k',
            markersize=5,
            linestyle=' ')
    set_axis(ax, -0.05, 1.05, letter='B')
    ax.set_xlabel('R')
    ax.set_ylabel('Eigenvalues')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    depth = np.linspace(0, 1, 100)
    for i in range(OBJ_M[0].k_interp_cross.shape[1]):
        row, col = plt_cord[i]
        ax = fig.add_subplot(gs[row, col:col + 2])
        for idx, r_val in enumerate(R_init):
            a = np.dot(OBJ_M[idx].k_interp_cross, eigenvec_M[idx, :, i])
            a = a / np.linalg.norm(a)
            ax.plot(depth,
                    a,
                    linestyle=linestyles[idx],
                    color=colors[idx],
                    label='R=' + str(r_val),
                    lw=2)
        # Panel decoration is done once, after all curves are drawn; doing it
        # per curve stacked identical text artists on top of each other.
        ax.text(0.5,
                1.,
                r"$\tilde{K}\cdot{v_{{%(i)d}}}$" % {'i': i + 1},
                horizontalalignment='center',
                transform=ax.transAxes,
                fontsize=20)
        set_axis(ax, -0.10, 1.1, letter=letters[i])
        # Only the bottom row of panels (i >= 9) keeps its x axis.
        if i < 9:
            ax.get_xaxis().set_visible(False)
            ax.spines['bottom'].set_visible(False)
        else:
            ax.set_xlabel('Depth ($mm$)')
        # y label only on the left-most panel of each row.
        if i % 3 == 0:
            ax.set_ylabel('CSD ($mA/mm$)')
            ax.yaxis.set_label_coords(-0.18, 0.5)
        ax.ticklabel_format(style='sci', axis='y', scilimits=((0.0, 0.0)))
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
    fig.legend(ht, lh, loc='lower center', ncol=5, frameon=False)
    fig.savefig(os.path.join(
        save_path, 'vectors_' + '_noise_' + str(noise) + 'R0_8' + '.png'),
                dpi=300)
    plt.show()
def generate_figure_projection(R, MU, n_src, true_csd_xlims, total_ele,
                               method='cross-validation', Rs=None,
                               lambdas=None, noise=0, save_path=''):
    """
    Generates figure projecting the ground-truth CSD onto eigensources.

    Each row has four panels:

    * the CSD estimate;
    * the projection of the ground truth onto the eigensources;
    * the annihilated remainder (ground truth minus projection);
    * the eigensources themselves.

    Rows:

    * A-D: electrodes in [0, 0.5], basis sources in [0, 0.5].
    * E-H: electrodes in [0.5, 1], basis sources in [0, 0.5].
    * I-L: electrodes in [0.5, 1] (the same simulated data as E-H), basis
      sources in [0.5, 1].

    The estimation grid is 100 points spanning [0, 1] in every row.

    Parameters
    ----------
    R: float
        Thickness of the groundtruth source.
        Example value: 0.2.
    MU: float
        Central position of Gaussian source
        Example value: 0.25.
    n_src: int
        Number of basis sources.
    true_csd_xlims: list
        Boundaries for ground truth space.
    total_ele: int
        Number of electrodes.
    method: string
        Determines the method of regularization.
        Default: cross-validation.
    Rs: numpy 1D array
        Basis source parameter for crossvalidation.
        Default: None.
    lambdas: numpy 1D array
        Regularization parameter for crossvalidation.
        Default: None.
    noise: float
        Determines the level of noise in the data.
        Default: 0.
    save_path: string
        Directory in which the PNG file is written.
        Default: '' (current working directory).

    Returns
    -------
    None
    """
    # Ground truth sampled on the same 100-point grid used for estimation.
    true_csd_projection = tb.csd_profile(np.linspace(0, 1, 100), [R, MU])
    ele_lims = [0, 0.5]
    _, true_csd, ele_pos, pots, val = tb.simulate_data(tb.csd_profile,
                                                       true_csd_xlims, R,
                                                       MU, total_ele,
                                                       ele_lims,
                                                       noise=noise)

    fig = plt.figure(figsize=(15, 12))
    widths = [1, 1, 1, 1]
    heights = [1, 1, 1]
    gs = gridspec.GridSpec(3, 4, height_ratios=heights, width_ratios=widths,
                           hspace=0.45, wspace=0.3)

    # Row 1: electrodes [0, 0.5]; estimation [0, 1], sources [0, 0.5].
    own_est = np.linspace(0, 1, 100)
    own_src = np.linspace(0, 0.5, n_src)
    obj = modified_bases(val, pots, ele_pos, n_src, own_est, own_src, h=0.25,
                         sigma=0.3, method=method, Rs=Rs, lambdas=lambdas)

    ax = fig.add_subplot(gs[0, 0])
    fb.make_subplot(ax, true_csd, obj.values('CSD'), obj.estm_x,
                    ele_pos=ele_pos, title="CSD", xlabel=False, ylabel=True,
                    letter='A')

    eigensources = calculate_eigensources(obj)
    projection = csd_into_eigensource_projection(true_csd_projection,
                                                 eigensources)
    anihilator = true_csd_projection - projection

    ax = fig.add_subplot(gs[0, 1])
    plot_projection(ax, projection, obj.estm_x, ele_pos=ele_pos,
                    title='Projection', xlabel=False, ylabel=False,
                    letter='B')

    ax = fig.add_subplot(gs[0, 2])
    plot_projection(ax, anihilator, obj.estm_x, ele_pos=ele_pos,
                    title='Annihilated', xlabel=False, ylabel=False,
                    letter='C', c='r', label='Annihilated')

    ax = fig.add_subplot(gs[0, 3])
    make_subplot(ax, eigensources, obj.estm_x, ele_pos=ele_pos,
                 title='Eigensources', xlabel=False, ylabel=False, letter='D')

    # Row 2: electrodes [0.5, 1]; estimation [0, 1], sources [0, 0.5].
    ele_lims = [0.5, 1]
    _, true_csd, ele_pos, pots, val = tb.simulate_data(tb.csd_profile,
                                                       true_csd_xlims, R,
                                                       MU, total_ele,
                                                       ele_lims,
                                                       noise=noise)

    own_est = np.linspace(0., 1., 100)
    own_src = np.linspace(0., 0.5, n_src)
    obj = modified_bases(val, pots, ele_pos, n_src, own_est, own_src, h=0.25,
                         sigma=0.3, method=method, Rs=Rs, lambdas=lambdas)

    ax = fig.add_subplot(gs[1, 0])
    fb.make_subplot(ax, true_csd, obj.values('CSD'), obj.estm_x,
                    ele_pos=ele_pos, title=None, xlabel=False, ylabel=True,
                    letter='E')

    eigensources = calculate_eigensources(obj)
    projection = csd_into_eigensource_projection(true_csd_projection,
                                                 eigensources)
    anihilator = true_csd_projection - projection

    ax = fig.add_subplot(gs[1, 1])
    plot_projection(ax, projection, obj.estm_x, ele_pos=ele_pos,
                    title=None, xlabel=False, ylabel=False, letter='F')

    ax = fig.add_subplot(gs[1, 2])
    plot_projection(ax, anihilator, obj.estm_x, ele_pos=ele_pos,
                    title=None, xlabel=False, ylabel=False, letter='G', c='r')

    ax = fig.add_subplot(gs[1, 3])
    make_subplot(ax, eigensources, obj.estm_x, ele_pos=ele_pos,
                 title=None, xlabel=False, ylabel=False, letter='H')

    # Row 3: same data as row 2 (electrodes [0.5, 1]);
    # estimation [0, 1], sources [0.5, 1].
    own_est3 = np.linspace(0., 1., 100)
    own_src3 = np.linspace(0.5, 1., n_src)
    obj = modified_bases(val, pots, ele_pos, n_src, own_est3, own_src3,
                         h=0.25, sigma=0.3, method=method, Rs=Rs,
                         lambdas=lambdas)
    ax1 = fig.add_subplot(gs[2, 0])
    ax1 = fb.make_subplot(ax1, true_csd, obj.values('CSD'), obj.estm_x,
                          ele_pos=ele_pos, title=None, xlabel=True,
                          ylabel=True, letter='I')

    eigensources = calculate_eigensources(obj)
    projection = csd_into_eigensource_projection(true_csd_projection,
                                                 eigensources)
    anihilator = true_csd_projection - projection

    ax2 = fig.add_subplot(gs[2, 1])
    ax2 = plot_projection(ax2, projection, obj.estm_x, ele_pos=ele_pos,
                          title=None, xlabel=True, ylabel=False, letter='J')

    ax3 = fig.add_subplot(gs[2, 2])
    ax3 = plot_projection(ax3, anihilator, obj.estm_x, ele_pos=ele_pos,
                          title=None, xlabel=True, ylabel=False, letter='K',
                          c='r', label='Annihilated')

    ax4 = fig.add_subplot(gs[2, 3])
    ax4 = make_subplot(ax4, eigensources, obj.estm_x,
                       ele_pos=ele_pos, title=None, xlabel=True, ylabel=False,
                       letter='L')

    # get_legend_handles_labels() returns (handles, labels). Zipping the four
    # pairs concatenates the handles together and the labels together.
    handles, labels = [a + b + c + d for a, b, c, d in zip(
        ax1.get_legend_handles_labels(), ax2.get_legend_handles_labels(),
        ax3.get_legend_handles_labels(), ax4.get_legend_handles_labels())]

    fig.legend(handles, labels, loc='lower center', ncol=6, frameon=False)

    fig.savefig(os.path.join(save_path, 'Projection_targeted_basis_noise_' +
                             str(noise) + 'normalized2_oKCSD1D.png'),
                dpi=300)
    plt.show()
def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
                    noise=0, R_init=0.23):
    """
    Generates figure comparing projections for different numbers of basis
    sources M.

    Three panels are drawn, each with one curve per M in
    ``[2, 4, 8, 16, 32, 64, 128, 256, 512]``:

    * A: ``calculate_projection`` of the ground-truth CSD for that M.
    * B: ``calculate_diff`` between the ground truth and that projection.
    * C: ``calculate_diff`` between the ground truth and the estimate
      ``OBJ_M[i].values('CSD')``.

    The figure is saved as ``projections_different_M.png`` in the current
    working directory.

    Parameters
    ----------
    csd_profile: function
        Function to produce csd profile.
    R: float
        Thickness of the groundtruth source.
        Example value: 0.2.
    MU: float
        Central position of Gaussian source
        Example value: 0.25.
    true_csd_xlims: list
        Boundaries for ground truth space.
    total_ele: int
        Number of electrodes.
    ele_lims: list
        Electrodes limits.
    noise: float
        Determines the level of noise in the data.
        Default: 0.
    R_init: float
        Initial value of R parameter - width of basis source
        Default: 0.23.

    Returns
    -------
    true_csd: array
        Ground-truth CSD as returned by ``tb.simulate_data``.
    projection_M: list
        One projection per value of M, in the order listed above.
    """
    _, true_csd, ele_pos, pots, _ = tb.simulate_data(csd_profile,
                                                     true_csd_xlims,
                                                     R, MU, total_ele,
                                                     ele_lims,
                                                     noise=noise)
    n_src_M = [2, 4, 8, 16, 32, 64, 128, 256, 512]
    OBJ_M, _, eigenvec_M = stability_M(n_src_M,
                                       total_ele, ele_pos, pots,
                                       R_init=R_init)
    # Every curve is plotted against the same 100-point grid on [0, 1].
    depth = np.linspace(0, 1, 100)
    curve_labels = ['M=' + str(m) for m in n_src_M]
    fig = plt.figure(figsize=(10, 6))

    # Panel A: projection of the ground truth for every M.
    ax = fig.add_subplot(131)
    ax.set_title('Projection')
    projection_M = []
    for i, label in enumerate(curve_labels):
        projection = calculate_projection(true_csd, OBJ_M[i], eigenvec_M[i])
        projection_M.append(projection)
        ax.plot(depth, projection, label=label)
    ax.set_xlabel('Depth (mm)')
    ax.set_ylabel('CSD (mA/mm)')
    set_axis(ax, -0.05, 1.05, letter='A')

    # Panel B: ground truth minus its projection.
    ax = fig.add_subplot(132)
    ax.set_title('Error')
    for projection, label in zip(projection_M, curve_labels):
        ax.plot(depth, calculate_diff(true_csd, projection), label=label)
    ax.set_xlabel('Depth (mm)')
    set_axis(ax, -0.05, 1.05, letter='B')

    # Panel C: ground truth against the estimate. The ground truth is passed
    # as a column vector, presumably to match values('CSD') -- TODO confirm.
    true_csd_col = true_csd.reshape(true_csd.shape[0], 1)
    ax = fig.add_subplot(133)
    ax.set_title('Annihilator')
    for i, label in enumerate(curve_labels):
        err = calculate_diff(true_csd_col, OBJ_M[i].values('CSD'))
        ax.plot(depth, err, label=label)
    ax.set_xlabel('Depth (mm)')
    set_axis(ax, -0.05, 1.05, letter='C')
    plt.tight_layout()

    handles, labels = ax.get_legend_handles_labels()
    lgd = fig.legend(handles, labels, loc='lower center', ncol=len(n_src_M),
                     frameon=False, bbox_to_anchor=(0.5, -0.02), fontsize=10)
    # bbox_extra_artists keeps the legend (placed below the axes) inside the
    # saved image when bbox_inches='tight'.
    fig.savefig('projections_different_M.png', dpi=300,
                bbox_extra_artists=(lgd,), bbox_inches='tight')

    return true_csd, projection_M
예제 #6
0
def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
                    save_path, method='cross-validation', Rs=None,
                    lambdas=None, noise=0):
    """
    Generates figure for spectral structure decomposition.

    The figure has three parts:

    * Panel A shows the eigenvalue spectra for a subset of basis-source
      counts M.
    * Panel B shows the first eigenvalue as a function of M.
    * Panels C onwards show, for each eigenvector index i, the product
      ``k_interp_cross @ v_i`` for the same subset of M.

    The figure is saved to
    ``<save_path>/vectors_<method>_noise_<noise>.png`` (dpi=300) and then
    displayed with ``plt.show()``.

    Parameters
    ----------
    csd_profile: function
        Function to produce csd profile.
    R: float
        Thickness of the groundtruth source (required).
    MU: float
        Central position of Gaussian source (required).
    true_csd_xlims: list
        Boundaries for ground truth space.
    total_ele: int
        Number of electrodes.
    ele_lims: list
        Electrodes limits.
    save_path: string
        Directory in which the figure is saved.
    method: string
        Determines the method of regularization.
        Default: cross-validation.
    Rs: numpy 1D array
        Basis source parameter for crossvalidation.
        Default: None.
    lambdas: numpy 1D array
        Regularization parameter for crossvalidation.
        Default: None.
    noise: float
        Determines the level of noise in the data.
        Default: 0.

    Returns
    -------
    None
    """
    # Only electrode positions and potentials are needed here; the
    # ground-truth outputs of simulate_data are not used by this figure.
    _, _, ele_pos, pots, _ = tb.simulate_data(csd_profile, true_csd_xlims,
                                              R, MU, total_ele, ele_lims,
                                              noise=noise)

    n_src_M = [2, 4, 8, 16, 32, 64, 128, 256, 512]
    OBJ_M, eigenval_M, eigenvec_M = stability_M(csd_profile, n_src_M,
                                                ele_lims, true_csd_xlims,
                                                total_ele, ele_pos, pots,
                                                method=method, Rs=Rs,
                                                lambdas=lambdas)

    # (grid row, first grid column) of each component panel: four rows of
    # three panels below the eigenvalue panels, each panel two columns wide.
    plt_cord = [(2, 0), (2, 2), (2, 4),
                (3, 0), (3, 2), (3, 4),
                (4, 0), (4, 2), (4, 4),
                (5, 0), (5, 2), (5, 4)]
    # Panel letters continue after A and B, one per component panel.
    letters = ['C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N']

    BLACK = _html(0, 0, 0)
    ORANGE = _html(230, 159, 0)
    SKY_BLUE = _html(86, 180, 233)
    GREEN = _html(0, 158, 115)
    YELLOW = _html(240, 228, 66)
    BLUE = _html(0, 114, 178)
    VERMILION = _html(213, 94, 0)
    PURPLE = _html(204, 121, 167)
    colors = [BLUE, ORANGE, GREEN, PURPLE, VERMILION, SKY_BLUE, YELLOW, BLACK]

    fig = plt.figure(figsize=(18, 16))
    # Row 0 holds panels A/B, row 1 is a thin spacer, rows 2-5 the components.
    heights = [4, 0.3, 1, 1, 1, 1]
    markers = ['^', '.', '*', 'x', ',']
    linestyles = ['-', '-', '-', '-']
    # Indices into n_src_M of the M values drawn in panel A and in the
    # component panels (M = 2, 8, 16, 512).
    src_idx = [0, 2, 3, 8]

    gs = gridspec.GridSpec(6, 6, height_ratios=heights, hspace=0.3, wspace=0.6)

    # Panel A: eigenvalue spectrum for each selected M.
    ax = fig.add_subplot(gs[0, :3])
    for indx, i in enumerate(src_idx):
        ax.plot(np.arange(1, total_ele + 1), eigenval_M[i],
                linestyle=linestyles[indx], color=colors[indx],
                marker=markers[indx], label='M='+str(n_src_M[i]),
                markersize=10)
    # These handles/labels are reused for the figure-level legend below.
    ht, lh = ax.get_legend_handles_labels()
    set_axis(ax, -0.05, 1.05, letter='A')
    ax.set_xlabel('Number of components')
    ax.set_ylabel('Eigenvalues')
    ax.set_yscale('log')
    ax.set_ylim([1e-6, 1])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Panel B: first eigenvalue as a function of the number of basis sources.
    ax = fig.add_subplot(gs[0, 3:])
    ax.plot(n_src_M, eigenval_M[:, 0], marker='s', color='k', markersize=5,
            linestyle=' ')
    set_axis(ax, -0.05, 1.05, letter='B')
    ax.set_xlabel('Number of basis sources')
    ax.set_xscale('log')
    ax.set_ylabel('Eigenvalues')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Panels C onwards: k_interp_cross projected onto each eigenvector.
    n_components = OBJ_M[0].k_interp_cross.shape[1]
    for i in range(n_components):
        row, col = plt_cord[i]
        ax = fig.add_subplot(gs[row, col:col + 2])
        for idx, j in enumerate(src_idx):
            k_interp_cross = OBJ_M[j].k_interp_cross
            # The x grid must have one point per row of k_interp_cross.
            ax.plot(np.linspace(0, 1, k_interp_cross.shape[0]),
                    np.dot(k_interp_cross, eigenvec_M[j, :, i]),
                    linestyle=linestyles[idx], color=colors[idx],
                    label='M='+str(n_src_M[j]), lw=2)

        # Decorate each panel once, after all of its curves are plotted, so
        # the title text and panel letter are not stacked as duplicates.
        ax.text(0.5, 1., r"$\tilde{K}*v_{{%(i)d}}$" % {'i': i+1},
                horizontalalignment='center', transform=ax.transAxes,
                fontsize=20)
        set_axis(ax, -0.10, 1.1, letter=letters[i])
        # Only the last three panels (the bottom of each column) get an
        # x axis.
        if i < n_components - 3:
            ax.get_xaxis().set_visible(False)
            ax.spines['bottom'].set_visible(False)
        else:
            ax.set_xlabel('Depth ($mm$)')
        # Only the first panel of each row gets a y label.
        if i % 3 == 0:
            ax.set_ylabel('CSD ($mA/mm$)')
            ax.yaxis.set_label_coords(-0.18, 0.5)
        ax.ticklabel_format(style='sci', axis='y', scilimits=((0.0, 0.0)))
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    fig.legend(ht, lh, loc='lower center', ncol=5, frameon=False)
    fig.savefig(os.path.join(save_path, 'vectors_' + method +
                             '_noise_' + str(noise) + '.png'), dpi=300)

    plt.show()