def visual_test_spherical_poly_conditional(signal=None,
                                           bvecs_fname='Ramon_dwi.bvecs',
                                           resolution=100):
    """Plot a Gaussian-process fit of one voxel's DWI signal on the sphere.

    The zero-mean voxel signal is conditioned on its b-vectors with
    ``gp.spherical_poly_conditional``. The predicted signal is evaluated on a
    ``resolution`` x ``resolution`` longitude/latitude mesh of the unit sphere
    and drawn as a surface whose radius is the predicted signal. The measured
    samples are drawn as red points at ``signal * b`` and at ``-signal * b``.

    Parameters
    ----------
    signal : array_like, optional
        One value per b-vector. Defaults to the module-level ``rcc_signal``.
    bvecs_fname : str, optional
        Text file read with ``np.loadtxt``. The first three columns of each
        row are used as that sample's b-vector.
    resolution : int, optional
        Number of samples along each spherical angle of the surface mesh.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the 3-D plot.

    Raises
    ------
    ValueError
        If ``signal`` is not 1-D with exactly one entry per b-vector.
    """
    import experiments.registration.gproc as gp
    # Imported for its side effect: older matplotlib versions need it to
    # register the '3d' projection used below.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import matplotlib.pyplot as plt
    import numpy as np

    B = np.loadtxt(bvecs_fname)[:, :3]
    # Hard-coded noise samples; only their variance is used (as the GP
    # noise variance).
    noise = np.array([21, 19, 26, 16, 20, 20, 26, 14, 24, 8, 37, 9, 39, 29,
                      17, 7, 13, 23, 12, 55, 40, 19, 6, 5, 20, 23, 31, 25,
                      14, 19, 22, 14], dtype=np.float64)
    # Another example voxel signal that can be passed as ``signal``:
    # [186, 107, 167, 250, 170, 135, 93, 250, 138, 95, 169, 207, 177, 160,
    #  247, 188, 116, 235, 199, 192, 153, 237, 176, 115, 228, 200, 90, 157,
    #  194, 216, 94, 128]
    if signal is None:
        signal = rcc_signal
    signal = np.array(signal, dtype=np.float64)
    if signal.ndim != 1 or signal.shape[0] != B.shape[0]:
        raise ValueError('expected one signal value per b-vector: got signal '
                         'shape %s for %d b-vectors'
                         % (signal.shape, B.shape[0]))

    # The GP is fitted to the zero-mean signal; the mean is added back below.
    mean_signal = signal.mean()
    f_in = signal - mean_signal
    x_in = B.copy()
    sigmasq_signal = f_in.var()
    sigmasq_noise = noise.var()

    # Longitude (u) / latitude (v) mesh of the unit sphere.
    u = np.linspace(0, 2 * np.pi, resolution)
    v = np.linspace(0, np.pi, resolution)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    x_out = np.column_stack((x.ravel(), y.ravel(), z.ravel()))

    mean_out, _ = gp.spherical_poly_conditional(
        x_in, f_in, x_out, sigmasq_signal, sigmasq_noise)
    # The predicted signal becomes the radius of each mesh point.
    predicted = np.array(mean_out).reshape(x.shape) + mean_signal
    x = x * predicted
    y = y * predicted
    z = z * predicted

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='b', alpha=0.2,
                    shade=True)
    # Each sample at signal_i * b_i (row-wise scaling), plus its mirror image
    # through the origin. ``.T.copy()`` gives contiguous coordinate arrays.
    xs, ys, zs = (signal[:, np.newaxis] * x_in).T.copy()
    ax.scatter(xs, ys, zs, c='r', s=40)
    ax.scatter(-xs, -ys, -zs, c='r', s=40)
    return fig
def _nearest_pixel_index(coord, size):
    """Return the index of the ``imshow`` pixel containing data coordinate *coord*.

    With the default extent, ``imshow`` centres pixel ``i`` on coordinate
    ``i`` (it spans ``[i - 0.5, i + 0.5]``). The coordinate is therefore
    rounded half-up rather than truncated. The result is clipped to
    ``[0, size - 1]`` so that clicks on the image border never index outside
    the volume.
    """
    return int(np.clip(np.floor(coord + 0.5), 0, size - 1))


def draw_rect_dwi(event, side=2, sigmasq_noise=100.0):
    """Mouse-click callback: inspect the DWI signal of the voxel under the cursor.

    Events outside the first axes of ``global_figure`` are ignored. Otherwise
    ``global_figure`` is cleared and redrawn with two panels:

    * left: the selected slice of ``global_map['residuals']`` (transposed,
      ``origin='lower'``), with a red ``side`` x ``side`` square centred on
      the clicked voxel;
    * right (3-D): the voxel's DWI values drawn at ``value * bvec`` and
      mirrored through the origin, together with the surface predicted by
      ``gp.spherical_poly_conditional`` fitted to those values.

    The residual at the voxel and the voxel's DWI values are printed. The DWI
    values are also stored in the module-level ``sel_signal``.

    Parameters
    ----------
    event : matplotlib.backend_bases.MouseEvent
        Its ``inaxes``, ``xdata`` and ``ydata`` attributes are read.
    side : int, optional
        Side length, in pixels, of the square marking the selection.
    sigmasq_noise : float, optional
        Noise variance passed to the GP. The default 100.0 is a placeholder;
        it should eventually be estimated from the data.

    Notes
    -----
    Reads the module-level ``global_figure`` and ``global_map``. The latter
    must provide 'dwi', 'bvecs', 'residuals', 'slice_type', 'slice_index',
    'vmin' and 'vmax'. A ``slice_index`` of None selects the middle slice.
    """
    import experiments.registration.gproc as gp
    from matplotlib.patches import Rectangle
    global sel_signal

    if event.inaxes != global_figure.axes[0]:
        return

    dwi = global_map['dwi']
    bvecs = global_map['bvecs']
    residuals = global_map['residuals']
    slice_type = global_map['slice_type']
    slice_index = global_map['slice_index']
    vmin = global_map['vmin']
    vmax = global_map['vmax']

    # Any slice_type other than 0 or 1 selects axis 2.
    normal_axis = slice_type if slice_type in (0, 1) else 2
    if slice_index is None:
        slice_index = dwi.shape[normal_axis] // 2
    image = np.take(residuals, slice_index, axis=normal_axis)

    # The slice is displayed transposed, so display x runs along image axis 0
    # and display y along image axis 1. These are the two remaining volume
    # axes, in increasing order.
    in_plane = [axis for axis in range(3) if axis != normal_axis]
    px = _nearest_pixel_index(event.xdata, image.shape[0])
    py = _nearest_pixel_index(event.ydata, image.shape[1])
    voxel = [slice_index] * 3
    voxel[in_plane[0]] = px
    voxel[in_plane[1]] = py
    vx, vy, vz = voxel

    global_figure.clf()
    ax_img = global_figure.add_subplot(1, 2, 1)
    ax_img.imshow(image.T, origin='lower', vmin=vmin, vmax=vmax)
    print("V[%d,%d,%d]=%f\n" % (vx, vy, vz, residuals[vx, vy, vz]))

    sel_signal = dwi[vx, vy, vz, :].copy()
    print(sel_signal)

    # Fit the GP to the zero-mean signal; the mean is added back below.
    mean_signal = sel_signal.mean()
    f_in = sel_signal - mean_signal
    x_in = bvecs.copy()
    sigmasq_signal = f_in.var()

    # Longitude (u) / latitude (v) mesh of the unit sphere.
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    sx = np.outer(np.cos(u), np.sin(v))
    sy = np.outer(np.sin(u), np.sin(v))
    sz = np.outer(np.ones(np.size(u)), np.cos(v))
    x_out = np.column_stack((sx.ravel(), sy.ravel(), sz.ravel()))

    # Conditional mean at the mesh points; the covariance is not used here.
    mean_out, _ = gp.spherical_poly_conditional(
        x_in, f_in, x_out, sigmasq_signal, sigmasq_noise)
    # The predicted signal becomes the radius of each mesh point.
    radius = np.array(mean_out).reshape(sx.shape) + mean_signal

    ax_3d = global_figure.add_subplot(1, 2, 2, projection='3d')
    ax_3d.plot_surface(sx * radius, sy * radius, sz * radius, rstride=4,
                       cstride=4, color='b', alpha=0.2, shade=True)
    # Each sample at signal_i * b_i (row-wise scaling), plus its mirror image
    # through the origin. ``.T.copy()`` gives contiguous coordinate arrays.
    xs, ys, zs = (sel_signal[:, np.newaxis] * x_in).T.copy()
    ax_3d.scatter(xs, ys, zs, c='r', s=40)
    ax_3d.scatter(-xs, -ys, -zs, c='r', s=40)

    # clf() above created fresh image axes, so there is no earlier selection
    # square to remove.
    ax_img.add_artist(Rectangle((px - side // 2, py - side // 2), side, side,
                                facecolor='none', linewidth=3,
                                edgecolor='#DD0000'))
    global_figure.canvas.draw_idle()