def visual_test_spherical_poly_conditional(signal=None, bvecs_fname='Ramon_dwi.bvecs'):
    """Plot a spherical-polynomial GP fit of a signal sampled on gradient directions.

    The signal (one value per gradient direction) is centred and passed to
    ``gp.spherical_poly_conditional``.  The resulting conditional mean is
    drawn as a radial surface over a 100x100 longitude/colatitude grid of
    the unit sphere.  The measured samples, each scaled along its direction,
    are overlaid as red dots together with their antipodal copies.

    Parameters
    ----------
    signal : array_like, optional
        One value per row of the b-vector file.  Defaults to the
        module-level ``rcc_signal``, the input this check was originally
        hard-wired to.
    bvecs_fname : str, optional
        Whitespace-delimited text file whose first three columns are the
        gradient directions (one row per signal value).
    """
    import experiments.registration.gproc as gp
    # Imported for its side effect of registering the '3d' projection
    # (required by older matplotlib releases).
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import matplotlib.pyplot as plt
    import numpy as np

    if signal is None:
        signal = rcc_signal
    x_in = np.loadtxt(bvecs_fname)[:, :3]
    signal = np.array(signal, dtype=np.float64)

    # Sample values used only to estimate the noise variance for the GP.
    noise = np.array([21, 19, 26, 16, 20, 20, 26, 14, 24,  8,
                      37,  9, 39, 29, 17,  7, 13, 23, 12, 55, 40,
                      19,  6,  5, 20, 23, 31, 25, 14, 19, 22, 14], dtype=np.float64)

    # The GP is fitted to the zero-mean signal; the mean is added back below.
    mean_signal = signal.mean()
    f_in = signal - mean_signal
    sigmasq_signal = f_in.var()
    sigmasq_noise = noise.var()

    # Unit-sphere grid: u is the azimuth in [0, 2*pi], v the polar angle in [0, pi].
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    x_out = np.column_stack([x.ravel(), y.ravel(), z.ravel()])

    # Only the conditional mean is plotted; the covariance is discarded.
    mean_out, _ = gp.spherical_poly_conditional(
        x_in, f_in, x_out, sigmasq_signal, sigmasq_noise)
    predicted = np.array(mean_out).reshape(x.shape) + mean_signal

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Radial surface: each grid direction scaled by the predicted signal.
    ax.plot_surface(x * predicted, y * predicted, z * predicted,
                    rstride=4, cstride=4, color='b', alpha=0.2, shade=True)
    # Scale each direction (row of x_in) by its measured value; same result as
    # diag(signal).dot(x_in) without building an n x n matrix.
    xs, ys, zs = (signal[:, np.newaxis] * x_in).T.copy()
    ax.scatter(xs, ys, zs, c='r', s=40)
    # Antipodal copies of the samples.
    ax.scatter(-xs, -ys, -zs, c='r', s=40)


def draw_rect_dwi(event):
    """Mouse callback: show the GP-fitted profile of the clicked voxel.

    Expected to be connected as a matplotlib mouse-event callback (it reads
    ``event.inaxes``, ``event.xdata`` and ``event.ydata``).  Clicks outside
    the first axes of ``global_figure`` are ignored.  Otherwise
    ``global_figure`` is cleared and redrawn with two panels:

    * left: the selected slice of ``global_map['residuals']``, with a red
      square marking the clicked voxel;
    * right: a 3-D surface of the Gaussian-process conditional mean of the
      voxel's signal over the unit sphere, plus the measured samples scaled
      along their gradient directions (and their antipodal copies) as red
      dots.

    Reads the module globals ``global_figure`` and ``global_map``.  The keys
    used are 'dwi', 'bvecs', 'residuals', 'slice_type', 'slice_index', 'vmin'
    and 'vmax'.  Stores the clicked voxel's signal in the module global
    ``sel_signal`` and prints the voxel's residual and signal.
    """
    global global_figure
    global sel_signal
    global global_map
    import experiments.registration.gproc as gp
    # Imported for its side effect of registering the '3d' projection
    # (required by older matplotlib releases).
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    from matplotlib.patches import Rectangle
    import numpy as np

    if event.inaxes != global_figure.axes[0]:
        return

    side = 2  # edge length, in pixels, of the marker square
    dwi = global_map['dwi']
    bvecs = global_map['bvecs']
    residuals = global_map['residuals']
    slice_type = global_map['slice_type']
    slice_index = global_map['slice_index']
    vmin = global_map['vmin']
    vmax = global_map['vmax']
    if slice_index is None:
        slice_index = dwi.shape[slice_type] // 2

    # 2-D slice shown on the left; it is displayed transposed, so the click's
    # x coordinate indexes its first axis and the y coordinate its second.
    if slice_type == 0:
        image = residuals[slice_index, :, :]
    elif slice_type == 1:
        image = residuals[:, slice_index, :]
    else:
        image = residuals[:, :, slice_index]

    # imshow centres pixel i at coordinate i (edges at i +/- 0.5), so round the
    # click position rather than truncating it.  Clamp to stay in bounds,
    # because a click on the outer half-pixel edge can round past the last index.
    px = min(max(int(round(event.xdata)), 0), image.shape[0] - 1)
    py = min(max(int(round(event.ydata)), 0), image.shape[1] - 1)
    if slice_type == 0:
        x, y, z = slice_index, px, py
    elif slice_type == 1:
        x, y, z = px, slice_index, py
    else:
        x, y, z = px, py, slice_index

    # Redraw on the callback's own figure, not on whatever figure is current.
    global_figure.clf()
    img_ax = global_figure.add_subplot(1, 2, 1)
    img_ax.imshow(image.T, origin='lower', vmin=vmin, vmax=vmax)

    print("V[%d,%d,%d]=%f\n" % (x, y, z, residuals[x, y, z]))
    sel_signal = dwi[x, y, z, :].copy()
    print(sel_signal)

    # Fit the Gaussian process to the zero-mean voxel signal.
    mean_signal = sel_signal.mean()
    f_in = sel_signal - mean_signal
    x_in = bvecs.copy()
    sigmasq_signal = f_in.var()
    sigmasq_noise = 100.0  # TODO: estimate the noise variance from the data

    # Unit-sphere grid: u is the azimuth in [0, 2*pi], v the polar angle in [0, pi].
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    sx = np.outer(np.cos(u), np.sin(v))
    sy = np.outer(np.sin(u), np.sin(v))
    sz = np.outer(np.ones(np.size(u)), np.cos(v))
    x_out = np.column_stack([sx.ravel(), sy.ravel(), sz.ravel()])

    # Only the conditional mean is plotted; the covariance is discarded.
    mean_out, _ = gp.spherical_poly_conditional(
        x_in, f_in, x_out, sigmasq_signal, sigmasq_noise)
    predicted = np.array(mean_out).reshape(sx.shape) + mean_signal

    surf_ax = global_figure.add_subplot(1, 2, 2, projection='3d')
    # Radial surface: each grid direction scaled by the predicted signal.
    surf_ax.plot_surface(sx * predicted, sy * predicted, sz * predicted,
                         rstride=4, cstride=4, color='b', alpha=0.2, shade=True)
    # Scale each direction (row of x_in) by its measured value; same result as
    # diag(sel_signal).dot(x_in) without building an n x n matrix.
    xs, ys, zs = (sel_signal[:, np.newaxis] * x_in).T.copy()
    surf_ax.scatter(xs, ys, zs, c='r', s=40)
    # Antipodal copies of the samples.
    surf_ax.scatter(-xs, -ys, -zs, c='r', s=40)

    # Mark the selected voxel.  The figure was just cleared, so there is no
    # previous marker to remove.
    marker = Rectangle((px - side // 2, py - side // 2), side, side,
                       facecolor='none', linewidth=3, edgecolor='#DD0000')
    img_ax.add_artist(marker)
    global_figure.canvas.draw_idle()