def test_simulated_image_welch(length_simul=514,
                title_prefix='test_simulated_image', figure='smiley', lbda=1, size=10, mask=True):
    """Display the two outputs of the Welch/TV estimator on a simulated picture.

    A picture is built with ``opas`` and a signal is simulated from it; the
    masked signal is handed to ``welch_estimator.welch_tv_estimator`` and
    both returned arrays are shown, each in its own figure with a colorbar.

    Parameters
    ----------
    length_simul : forwarded as ``lsimul`` to
        ``opas.get_simulation_from_picture``.
    title_prefix : not used by this function.
    figure : ``'smiley'`` selects ``opas.smiley``; any other value selects
        ``opas.square2``.
    lbda : forwarded as the third argument of ``welch_tv_estimator``.
    size : forwarded to the picture builder.
    mask : when truthy only pixels where the picture exceeds 0.1 are kept,
        otherwise every pixel is kept.
    """
    picture = opas.smiley(size) if figure == 'smiley' else opas.square2(size)
    pixel_mask = picture > 0.1 if mask else np.ones(picture.shape, dtype=bool)

    masked_signal = opas.get_simulation_from_picture(
        picture, lsimul=length_simul)[pixel_mask]
    outputs = welch_estimator.welch_tv_estimator(masked_signal, pixel_mask, lbda)
    estimate, regularized = outputs

    for position, values in enumerate((estimate, regularized)):
        if position:
            # Every image after the first gets a fresh figure.
            plt.figure()
        plt.imshow(_unmask(values, pixel_mask))
        plt.colorbar()
    plt.show()
def _save_single_image_figure(image, title, vmin, vmax, path):
    """Draw ``image`` in a new figure with a colorbar and save it to ``path``.

    The colour scale is pinned to ``[vmin, vmax]`` and the axes are hidden.
    """
    fig = plt.figure()
    plt.title(title)
    im = plt.imshow(image, norm=Normalize(vmin=vmin, vmax=vmax),
                    interpolation='nearest')
    plt.axis('off')
    # Separate axes on the right edge of the figure hold the colorbar.
    cax = fig.add_axes([0.91, 0.1, 0.028, 0.8])
    fig.colorbar(im, cax=cax)
    fig.savefig(path)


def test_simulated_image(j1=3, j2=6, wtype=1, length_simul=514,
                title_prefix='test_simulated_image', figure='smiley', size=10, mask=True,
                output_dir='/volatile/hubert/beamer/graphics/juillet2015/'):
    """Compare L2- and TV-penalized estimates on a simulated picture.

    Steps performed:

    1. build a picture with ``opas`` and simulate one signal per pixel;
    2. compute a per-pixel estimate with ``wtspecq_statlog32`` and
       ``regrespond_det2``;
    3. for 15 penalty values (0, then ``1.5 ** k`` for k = 0..13) minimise
       the L2-penalized loss with ``fmin_l_bfgs_b`` and run
       ``penalized.mtvsolver``, both started from the masked estimate;
    4. compute the RMSE of every result against the picture and remember,
       for each method, the penalty with the smallest RMSE;
    5. save and show the figures (grids of all results, best result per
       method, difference with the picture, RMSE curves).

    Parameters
    ----------
    j1, j2 : forwarded to ``regrespond_det2`` and the ``penalized``
        functions; also select the entries ``j1 - 1 .. min(j2, len(nj)) - 1``
        of ``nj`` used for the Lipschitz constant.
    wtype : forwarded to ``regrespond_det2``.
    length_simul : forwarded as ``lsimul`` to the simulator;
        ``int(log2(length_simul))`` is passed to ``wtspecq_statlog32``.
    title_prefix : prefix of the figure titles and file names.
    figure : ``'smiley'`` selects ``opas.smiley``; any other value selects
        ``opas.square2``.
    size : forwarded to the picture builder.
    mask : when truthy only pixels where the picture exceeds 0.1 are kept,
        otherwise every pixel is kept.
    output_dir : string prepended verbatim to every saved file name, so it
        must end with a path separator.
    """
    if figure == 'smiley':
        s = opas.smiley(size)
    else:
        s = opas.square2(size)

    # Keep the boolean parameter and the boolean array under distinct names.
    if mask:
        pixel_mask = s > 0.1
    else:
        pixel_mask = np.ones(s.shape, dtype=bool)
    flat_mask = pixel_mask.ravel()

    signal = opas.get_simulation_from_picture(s, lsimul=length_simul)
    signalshape = signal.shape
    # One row per pixel, one column per sample.
    sig = np.reshape(signal, (signalshape[0] * signalshape[1], signalshape[2]))
    n_pixels = sig.shape[0]

    estimate = np.zeros(n_pixels)
    aest = np.zeros(n_pixels)
    simulation = np.cumsum(sig, axis=1)

    #######################################################################

    dico = wtspecq_statlog32(simulation, 2, 1, np.array(2),
                             int(np.log2(length_simul)), 0, 0)
    Elog = dico['Elogmuqj'][:, 0]
    Varlog = dico['Varlogmuqj'][:, 0]
    nj = dico['nj']

    for pixel in range(n_pixels):
        sortie = regrespond_det2(Elog[pixel], Varlog[pixel], 2, nj, j1, j2, wtype)
        # Original author's note: normally Zeta (it is halved here).
        estimate[pixel] = sortie['Zeta'] / 2.
        aest[pixel] = sortie['aest']

    #######################################################################

    def l2_objective(x, lbda):
        """Return (loss, gradient) of the L2-penalized criterion at ``x``."""
        value = penalized.loss_l2_penalization_on_grad(
            x, aest[flat_mask], Elog[flat_mask], Varlog[flat_mask],
            nj, j1, j2, pixel_mask, l=lbda)
        # Original author's note: we set epsilon to 0.
        gradient = penalized.grad_loss_l2_penalization_on_grad(
            x, aest[flat_mask], Elog[flat_mask], Varlog[flat_mask],
            nj, j1, j2, pixel_mask, l=lbda)
        return value, gradient

    def l2_algo(lbda):
        """Minimise the L2-penalized criterion with L-BFGS-B for one penalty."""
        return fmin_l_bfgs_b(lambda x: l2_objective(x, lbda), estimate[flat_mask])

    l2_title = title_prefix + 'loss_l2_penalisation_on_grad'

    #######################################################################

    j22 = np.min((j2, len(nj)))
    j1j2 = np.arange(j1 - 1, j22)
    njj = nj[j1j2]
    # True division on purpose: if nj holds integers, a plain "/" under
    # Python 2 without true division would floor every weight to zero.
    wvarjj = np.true_divide(njj, sum(njj))
    lipschitz_constant = np.sum(8 * ((j1j2 + 1) ** 2) * wvarjj)
    l1_ratio = 0

    def tv_algo(lbda):
        """Run ``penalized.mtvsolver`` for one penalty value."""
        return penalized.mtvsolver(estimate[flat_mask], aest[flat_mask],
                                   Elog[flat_mask], Varlog[flat_mask],
                                   nj, j1, j2, pixel_mask,
                                   lipschitz_constant=lipschitz_constant,
                                   l1_ratio=l1_ratio, l=lbda)

    tv_title = title_prefix + 'wetvp'

    #######################################################################

    lmax = 15
    l2_minimizor = np.zeros((lmax,) + s.shape)
    l2_rmse = np.zeros(lmax)
    tv_minimizor = np.zeros((lmax,) + s.shape)
    tv_rmse = np.zeros(lmax)

    r = np.arange(lmax)
    lambdas = np.array((0,) + tuple(1.5 ** r[:-1]))

    # Index of the smallest RMSE seen so far (first occurrence wins).
    l2_min_rmse_idx = 0
    tv_min_rmse_idx = 0
    for idx in r:
        l2_minimizor[idx] = _unmask(l2_algo(lambdas[idx])[0], pixel_mask)
        l2_rmse[idx] = np.sqrt(np.mean((l2_minimizor[idx] - s) ** 2))
        if l2_rmse[idx] < l2_rmse[l2_min_rmse_idx]:
            l2_min_rmse_idx = idx

        tv_minimizor[idx] = _unmask(tv_algo(lambdas[idx])[0], pixel_mask)
        tv_rmse[idx] = np.sqrt(np.mean((tv_minimizor[idx] - s) ** 2))
        # The TV minimum is tracked against TV values only; it used to be
        # seeded with the L2 RMSE of the first penalty.
        if tv_rmse[idx] < tv_rmse[tv_min_rmse_idx]:
            tv_min_rmse_idx = idx

    #######################################################################

    ncols = int(ceil(lmax / 3.))
    for title, minimizor in ((tv_title, tv_minimizor), (l2_title, l2_minimizor)):
        # NOTE(review): this puts the title on figure 1, not on the grids
        # created just below - confirm whether fig.suptitle was intended.
        plt.figure(1)
        plt.title(title)

        fig, axes = plt.subplots(nrows=3, ncols=ncols)
        fig2, axes2 = plt.subplots(nrows=3, ncols=ncols)
        shared_vmin = np.min(minimizor)
        shared_vmax = np.max(minimizor)
        for idx, (dat, ax, ax2) in enumerate(zip(minimizor, axes.flat, axes2.flat)):
            # First grid: one colour scale shared by every panel.
            im = ax.imshow(dat, norm=Normalize(vmin=shared_vmin, vmax=shared_vmax),
                           interpolation='nearest')
            ax.axis('off')
            ax.set_title(r"$\lambda$ = %.1f " % (lambdas[idx]))

            # Second grid: every panel uses its own colour scale.
            im2 = ax2.imshow(dat, interpolation='nearest')
            ax2.axis('off')
            ax2.set_title(r"$\lambda$ = %.1f " % (lambdas[idx]))

        cax = fig.add_axes([0.91, 0.1, 0.028, 0.8])
        fig.colorbar(im, cax=cax)
        cax2 = fig2.add_axes([0.91, 0.1, 0.028, 0.8])
        fig2.colorbar(im2, cax=cax2)
        # Only the shared-scale grid is written to disk.
        fig.savefig(output_dir + title + '_graph.pdf')

    _save_single_image_figure(
        l2_minimizor[l2_min_rmse_idx],
        r'l2 minimizor of rmse $\lambda$ = %.1f ' % (lambdas[l2_min_rmse_idx]),
        np.min(l2_minimizor), np.max(l2_minimizor),
        output_dir + l2_title + 'minimizor.pdf')

    _save_single_image_figure(
        tv_minimizor[tv_min_rmse_idx],
        r'tv minimizor of rmse $\lambda$ = %.1f ' % (lambdas[tv_min_rmse_idx]),
        np.min(tv_minimizor), np.max(tv_minimizor),
        output_dir + tv_title + 'minimizor.pdf')

    # Difference images share one colour scale so they can be compared.
    l2_diff = l2_minimizor[l2_min_rmse_idx] - s
    tv_diff = tv_minimizor[tv_min_rmse_idx] - s
    normalize_vmax = np.max((l2_diff, tv_diff))
    normalize_vmin = np.min((l2_diff, tv_diff))

    _save_single_image_figure(
        l2_diff,
        r'l2 minimizor of rmse, difference with original image $\lambda$ = %.1f '
        % (lambdas[l2_min_rmse_idx]),
        normalize_vmin, normalize_vmax,
        output_dir + l2_title + 'minimizordiff.pdf')

    _save_single_image_figure(
        tv_diff,
        r'tv minimizor of rmse, difference with original image $\lambda$ = %.1f '
        % (lambdas[tv_min_rmse_idx]),
        normalize_vmin, normalize_vmax,
        output_dir + tv_title + 'minimizordiff.pdf')

    fig3 = plt.figure()
    plt.plot(lambdas, l2_rmse, 'r', label='l2 rmse')
    plt.plot(lambdas, tv_rmse, 'b', label='tv rmse')
    plt.axvline(lambdas[l2_min_rmse_idx], color='r')
    plt.axvline(lambdas[tv_min_rmse_idx], color='b')
    plt.ylabel('rmse')
    plt.xlabel('lambda')
    plt.legend()

    # The file name keeps the L2 title, which is what the leaked loop
    # variable held here before.
    fig3.savefig(output_dir + l2_title + '_rmse.pdf')
    print(l2_title)

    plt.show()
def test_simulated_image2(j1=3, j2=6, wtype=1, length_simul=514,
                title_prefix='test_simulated_image', figure='smiley', size=10, mask=True):
    """Show the ``hdw_p`` estimates obtained from a 514- and a 4096-sample simulation.

    A picture is built with ``opas``; two signals of length 514 and 4096
    are simulated from it, masked, cumulatively summed along axis 1 and
    handed to ``hdw_p``. Half of the returned ``'Zeta'`` entry is displayed
    for each length, in figures 1 and 2.

    Parameters
    ----------
    j1, j2, wtype : forwarded to ``hdw_p``.
    length_simul : only used as ``int(log2(length_simul))`` in both
        ``hdw_p`` calls; the simulated lengths are fixed to 514 and 4096.
        NOTE(review): confirm that the 4096-sample call is meant to use
        this value too.
    title_prefix : not used by this function.
    figure : ``'smiley'`` selects ``opas.smiley``; any other value selects
        ``opas.square2``.
    size : forwarded to the picture builder.
    mask : when truthy only pixels where the picture exceeds 0.1 are kept,
        otherwise every pixel is kept.
    """
    picture = opas.smiley(size) if figure == 'smiley' else opas.square2(size)
    pixel_mask = picture > 0.1 if mask else np.ones(picture.shape, dtype=bool)

    # Both simulations are drawn before any estimation, shortest first.
    short_signal = opas.get_simulation_from_picture(picture, lsimul=514)[pixel_mask]
    long_signal = opas.get_simulation_from_picture(picture, lsimul=4096)[pixel_mask]

    def halved_zeta(masked_signal):
        """Integrate the signal and return half of the 'Zeta' output of hdw_p."""
        integrated = np.cumsum(masked_signal, axis=1)
        result = hdw_p(integrated, 2, 1, np.array(2),
                       int(np.log2(length_simul)), 0, wtype, j1, j2, 0)
        # Original author's note: normally Zeta (it is halved here).
        return result['Zeta'] / 2.

    estimate514 = halved_zeta(short_signal)
    estimate4096 = halved_zeta(long_signal)

    #######################################################################

    # Both images use the value range of the 514-sample estimate.
    # NOTE(review): presumably so the two figures are comparable - confirm.
    color_low = np.min(estimate514)
    color_high = np.max(estimate514)

    figures = []
    images = []
    for number, values in ((1, estimate514), (2, estimate4096)):
        figures.append(plt.figure(number))
        images.append(plt.imshow(_unmask(values, pixel_mask),
                                 norm=Normalize(vmin=color_low, vmax=color_high),
                                 interpolation='nearest'))
        plt.axis('off')

    # Colorbars are added once both images exist, each on its own axes.
    for current_figure, current_image in zip(figures, images):
        colorbar_axes = current_figure.add_axes([0.91, 0.1, 0.028, 0.8])
        current_figure.colorbar(current_image, cax=colorbar_axes)

    plt.show()