def test(model, env, rendering=True, max_timesteps=3000):
    result = {}
    result["model"] = model
    result["maps"] = test_seeds
    for seed in test_seeds:
        result[seed] = {}

    for seed in test_seeds:
        env.seed(seed)
        episode_reward = 0
        step = 0
        state = env.reset()
        while True:
            state = np.array(state)
            state = data_transform(state)
            state = state.unsqueeze(0)
            model.eval()
            model_action = model(state)
            a = model_action.detach().numpy()[0]
            state, r, done, info = env.step(a)
            episode_reward += r
            step += 1
            if rendering:
                env.render()
            if done or step > max_timesteps:
                break
        print("Track reward: {}".format(episode_reward))
        result[seed]["reward"] = episode_reward

    save_result(result)
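# Hedged usage sketch (not part of the original file): test() appears to
# expect a gym-style environment with the old 4-tuple step() API and a torch
# policy whose forward pass maps the transformed observation to a continuous
# action. The environment id below is a guess and `model` is assumed to be a
# trained policy loaded elsewhere; nothing here is taken from this repo.
def _example_test_usage(model):
    import gym
    env = gym.make("CarRacing-v0")  # hypothetical environment choice
    test(model, env, rendering=False, max_timesteps=1000)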
def inference():
    unet = U_Net(in_ch=1, out_ch=1).cuda()
    # unet.load_state_dict(torch.load('best.pth'))
    unet.load_state_dict(torch.load('实验/关于网络结构/lena/slim-8-last.pth'))

    # Original image used to train the network and its Fourier intensity
    image = cv2.imread('lena.jpg', cv2.IMREAD_GRAYSCALE)
    # image = (255 - image) * 0.1
    image = cv2.resize(image, (256, 256))

    # Support constraint (optional)
    # reduced_image = cv2.resize(image, (128, 128))
    # image = np.zeros((256, 256))
    # image[64:192, 64:192] = reduced_image

    phase_obj = image / 255 * 2 * np.pi * 0.5
    magnitudes = np.abs(np.fft.fft2(np.exp(1j * phase_obj)))

    magnitudes_t = Variable(data_transform(magnitudes).unsqueeze(0))
    retrieved_phase = unet(
        magnitudes_t.type(dtype)).data.cpu().squeeze(0).numpy()[0]

    # Handle the global phase-shift ambiguity by subtracting the minimum
    phase_obj -= np.min(phase_obj)
    retrieved_phase -= np.min(retrieved_phase)

    plt.figure()
    plt.imshow(phase_obj / (2 * np.pi), cmap='gray', vmin=0, vmax=0.5)
    plt.colorbar(ticks=[0, 0.5])
    plt.axis('off')

    plt.figure()
    plt.imshow(retrieved_phase / (2 * np.pi), cmap='gray', vmin=0, vmax=0.5)
    plt.colorbar(ticks=[0, 0.5])
    plt.axis('off')
    plt.show()
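# Hedged sketch (not in the original file): a forward-consistency check for
# inference(), mirroring the RMSE checks printed in run()/run_fresnel().
# `check_inference` is a hypothetical helper name.
def check_inference(magnitudes, retrieved_phase):
    # RMSE between the measured Fourier magnitudes and those implied by the
    # retrieved phase under the same unit-amplitude forward model.
    check_magnitudes = np.abs(np.fft.fft2(np.exp(1j * retrieved_phase)))
    return np.sqrt(np.mean((magnitudes - check_magnitudes)**2))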
def __getitem__(self, index):
    tensor = data_transform(self.tensors[0][index])
    return (tensor, ) + tuple(t[index] for t in self.tensors[1:])
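# Hedged sketch (assumption, not from the original source): __getitem__ above
# presumably belongs to a TensorDataset-style wrapper that applies
# data_transform to the first tensor only. A minimal container consistent
# with it might look like this; the class name is hypothetical.
from torch.utils.data import Dataset


class _TransformTensorDatasetSketch(Dataset):
    def __init__(self, *tensors):
        # All tensors must share the same first dimension (sample count).
        assert all(len(tensors[0]) == len(t) for t in tensors)
        self.tensors = tensors

    def __len__(self):
        return len(self.tensors[0])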
def run_fresnel():
    # Simulation parameters
    wavelength = 632.8e-6  # mm
    N = 256
    width = 10  # aperture width, mm
    k = 2 * np.pi / wavelength
    f = 2
    # z = 1000
    tmp = np.arange(-width // 2, width // 2, width / N)
    x, y = np.meshgrid(tmp, tmp)
    R = np.sqrt(x**2 + y**2)
    R[R > width / 2] = 0
    lens = np.exp(-1j * k * R**2 / (2 * f)) * cyl(x, y, width / 2)

    # Read the phase image
    image = cv2.imread('peppers_gray.tif', cv2.IMREAD_GRAYSCALE)
    print('original image shape ', image.shape)
    image = cv2.resize(image, (256, 256))
    print('original image max ', np.max(image))

    # For a phase object, scale the gray image into the 0-2π range
    phase_obj = image / 255 * 2 * np.pi * 0.5
    # phase_obj = 0
    print(np.max(phase_obj), np.min(phase_obj))

    # Network input
    u0 = lens * np.exp(1j * phase_obj)
    m_d2_d1 = 0.001
    uz = two_step_prop_fresnel(u0, wavelength, width / N, m_d2_d1 * width / N,
                               f)
    magnitudes = np.abs(uz)
    print('original amplitude max: %f min: %f' %
          (np.max(magnitudes), np.min(magnitudes)))

    # Add defocus terms
    defocus1 = 0.0005
    defocus2 = 0.0001
    magnitudes1 = np.abs(
        two_step_prop_fresnel(u0, wavelength, width / N, m_d2_d1 * width / N,
                              f + defocus1))
    magnitudes2 = np.abs(
        two_step_prop_fresnel(u0, wavelength, width / N, m_d2_d1 * width / N,
                              f + defocus2))
    print('rmse between in-focus and defocus intensity, first: %f second: %f' %
          (np.sqrt(np.mean((magnitudes - magnitudes1)**2)),
           np.sqrt(np.mean((magnitudes - magnitudes2)**2))))

    plt.figure()
    plt.plot(magnitudes[128, :])
    plt.show()

    plt.subplot(131)
    plt.imshow(magnitudes, cmap='gray')
    plt.title('magnitudes')
    plt.colorbar()
    plt.subplot(132)
    plt.imshow(magnitudes1, cmap='gray')
    plt.title('magnitudes1')
    plt.subplot(133)
    plt.imshow(magnitudes2, cmap='gray')
    plt.title('magnitudes2')
    plt.show()

    # plt.figure()
    # plt.plot(magnitudes[0, :], color='r')
    # plt.plot(magnitudes1[0, :], color='g')
    # plt.plot(magnitudes2[0, :], color='b')
    # plt.show()

    # magnitudes_t = torch.Tensor(magnitudes)
    magnitudes_t = Variable(data_transform(magnitudes).unsqueeze(0))
    magnitudes_t1 = Variable(data_transform(magnitudes1).unsqueeze(0))
    magnitudes_t2 = Variable(data_transform(magnitudes2).unsqueeze(0))

    # Optionally add Gaussian random noise to the measurements:
    # noise = magnitudes_t.clone()
    # noise1 = magnitudes_t1.clone()
    # noise2 = magnitudes_t2.clone()
    # magnitudes_t += Variable(noise.normal_() * 10)
    # # print('psnr of the original image and noises: ',
    # #       psnr(magnitudes_t.data.numpy() / np.max(tmp), tmp / np.max(tmp)))
    # # plt.subplot(221)
    # # plt.imshow(tmp.squeeze())
    # # plt.subplot(222)
    # # plt.imshow(magnitudes_t.squeeze().numpy())
    # # plt.subplot(212)
    # # plt.plot(tmp.squeeze()[0, :], color='r')
    # # plt.plot(magnitudes_t.squeeze().numpy()[0, :], color='b')
    # # plt.show()
    # magnitudes_t1 += Variable(noise1.normal_() * 10)
    # magnitudes_t2 += Variable(noise2.normal_() * 10)

    # Multiple modulations
    mse_loss, mse_loss2, retrieved_phase = retrieve_fresnel(
        magnitudes_t, magnitudes_t1, magnitudes_t2, phase_obj, defocus1,
        defocus2)
    # Single modulation
    # mse_loss, mse_loss2, retrieved_phase = retrieve_fresnel(
    #     magnitudes_t, magnitudes_t1, magnitudes_t2, phase_obj, defocus1, None)
    # No modulation
    # mse_loss, mse_loss2, retrieved_phase = retrieve_fresnel(
    #     magnitudes_t, magnitudes_t1, magnitudes_t2, phase_obj, None, None)

    np.save('mse_loss_2.npy', mse_loss2)

    # retrieved_phase = np.mod(retrieved_phase, 2 * np.pi)
    retrieved_phase *= 2 * np.pi * 0.5

    # print(phase_obj[150:156, 150:156])
    # print(retrieved_phase[150:156, 150:156])
    # print(phase_obj[0:6, 0:6])
    # print(retrieved_phase[0:6, 0:6])

    # Handle the global phase-shift ambiguity by subtracting the minimum
    phase_obj -= np.min(phase_obj)
    retrieved_phase -= np.min(retrieved_phase)
    print('phase rmse: ', np.sqrt(np.mean((phase_obj - retrieved_phase)**2)))

    plt.subplot(221)
    plt.imshow(phase_obj, cmap='gray')
    plt.title('phase_obj')
    plt.axis('off')
    plt.subplot(222)
    plt.imshow(retrieved_phase, cmap='gray')
    plt.title('retrieved_phase')
    plt.axis('off')
    plt.colorbar()
    plt.subplot(212)
    plt.plot(phase_obj[128, :], color='r')
    plt.plot(retrieved_phase[128, :], color='b')
    plt.show()

    real_part = torch.cos(
        torch.tensor(retrieved_phase).type(dtype)).unsqueeze(-1)
    image_part = torch.sin(
        torch.tensor(retrieved_phase).type(dtype)).unsqueeze(-1)
    complex_phase = torch.cat((real_part, image_part), dim=-1).squeeze()
    f_phase = torch.fft(complex_phase, signal_ndim=2)
    re = torch.index_select(f_phase, dim=2,
                            index=torch.tensor(0).type(torch.cuda.LongTensor))
    im = torch.index_select(f_phase, dim=2,
                            index=torch.tensor(1).type(torch.cuda.LongTensor))
    pred_intensity = torch.sqrt(re**2 + im**2).squeeze().data.cpu().numpy()

    check_magnitudes = np.abs(
        two_step_prop_fresnel(lens * np.exp(1j * retrieved_phase), wavelength,
                              width / N, m_d2_d1 * width / N, f))

    np.save('phase_obj', phase_obj)
    np.save('retrieved_phase', retrieved_phase)
    np.save('magnitudes', magnitudes)
    np.save('check_magnitudes', check_magnitudes)

    # print(magnitudes[150:156, 150:156])
    # print(pred_intensity[150:156, 150:156])
    # print(magnitudes[0:6, 0:6])
    # print(pred_intensity[0:6, 0:6])
    print('intensity rmse: ',
          np.sqrt(np.mean((magnitudes - pred_intensity)**2)))
    print('rmse between original magnitudes and fft of pred_phase:',
          np.sqrt(np.mean((magnitudes - check_magnitudes)**2)))
    print('magnitudes shape: ', magnitudes.shape)
    print('check_magnitudes shape: ', check_magnitudes.shape)
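# Hedged sketch (assumption): cyl() used in run_fresnel() is defined elsewhere
# in the repo and appears to act as a circular-aperture indicator; a minimal
# version consistent with that usage might look like this. `_cyl_sketch` is a
# hypothetical name.
def _cyl_sketch(x, y, radius):
    # 1 inside the circle of the given radius, 0 outside.
    return (np.sqrt(x**2 + y**2) <= radius).astype(np.float64)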
def run():
    # np.random.seed(233)
    # image = imageio.imread('cameraman.png', as_gray=True)
    # image = cv2.imread('mandril_gray.tif', cv2.IMREAD_GRAYSCALE)
    image = cv2.imread('zju.jpg', cv2.IMREAD_GRAYSCALE)
    # Remember to enable this when using test.jpg; otherwise the edge is pi
    # and the center is 0.
    image = (255 - image) * 0.5
    print('original image shape ', image.shape)
    image = cv2.resize(image, (256, 256))

    # Support constraint? The understanding of this concept is still unclear.
    # reduced_image = cv2.resize(image, (128, 128))
    # image = np.zeros((256, 256))
    # image[64:192, 64:192] = reduced_image
    print('original image max ', np.max(image))

    # For a phase object, scale the gray image into the 0-2π range
    phase_obj = image / 255 * 2 * np.pi * PHASE_RANGE
    print('max and min of phase_obj: ', np.max(phase_obj), np.min(phase_obj))

    # Network input: phase modulation with normalized (unit) amplitude
    magnitudes = np.abs(np.fft.fft2(np.exp(1j * phase_obj)))
    print('original amplitude max: %f min: %f' %
          (np.max(magnitudes), np.min(magnitudes)))

    # Add a defocus term, implemented with a Zernike-like surface.
    # It could be a defocus modulation or any other known modulation;
    # the deep phase decoder is effectively adding a phase modulation.
    # -----
    defocus_term = imageio.imread('modulate1.jpg', as_gray=True)
    defocus_term = cv2.resize(defocus_term, (256, 256))
    # defocus_term = defocus_term / 255 * 2 * np.pi
    # Factor the 2π out of this term so the regressed coefficient only needs
    # to lie in [0, 1]; the multiplication by 2π is applied at the network
    # output.
    defocus_term = defocus_term / 255
    defocus_term = Variable(data_transform(defocus_term)).type(dtype)
    # -----

    modulate_phase_image1 = imageio.imread('modulate1.jpg', as_gray=True)
    modulate_phase_image1 = cv2.resize(modulate_phase_image1, (256, 256))
    # Could this defocus coefficient be learned automatically by the network?
    phase_modulate1 = modulate_phase_image1 / 255 * 2 * np.pi * 5.12
    phase_obj1 = phase_obj + phase_modulate1

    # Try a second, different modulation
    modulate_phase_image2 = cv2.imread('modulate1.jpg', cv2.IMREAD_GRAYSCALE)
    modulate_phase_image2 = cv2.resize(modulate_phase_image2, (256, 256))
    phase_modulate2 = modulate_phase_image2 / 255 * 2 * np.pi * 10.81
    phase_obj2 = phase_obj + phase_modulate2

    # Third modulation image
    modulate_phase_image3 = cv2.imread('modulate1.jpg', cv2.IMREAD_GRAYSCALE)
    modulate_phase_image3 = cv2.resize(modulate_phase_image3, (256, 256))
    phase_modulate3 = modulate_phase_image3 / 255 * 2 * np.pi * 15.77
    phase_obj3 = phase_obj + phase_modulate3

    # magnitudes = np.abs(np.fft.fftshift(np.fft.fft2(np.exp(1j * phase_obj))))
    # If shifted here, the loss computation must be shifted as well.
    magnitudes1 = np.abs(np.fft.fft2(np.exp(1j * phase_obj1)))
    print('rmse between in-focus and defocus intensity 1: ',
          np.sqrt(np.mean((magnitudes - magnitudes1)**2)))
    magnitudes2 = np.abs(np.fft.fft2(np.exp(1j * phase_obj2)))
    print('rmse between in-focus and defocus intensity 2: ',
          np.sqrt(np.mean((magnitudes - magnitudes2)**2)))
    magnitudes3 = np.abs(np.fft.fft2(np.exp(1j * phase_obj3)))
    print('rmse between in-focus and defocus intensity 3: ',
          np.sqrt(np.mean((magnitudes - magnitudes3)**2)))

    # ---- Gray-level quantization ----
    magnitudes = np.square(magnitudes)
    magnitudes1 = np.square(magnitudes1)
    magnitudes2 = np.square(magnitudes2)
    magnitudes3 = np.square(magnitudes3)

    plt.figure()
    plt.imshow(magnitudes2)
    plt.show()

    # 14-bit camera
    print("max of intensity: %d %d %d %d" %
          (np.max(magnitudes), np.max(magnitudes1), np.max(magnitudes2),
           np.max(magnitudes3)))
    # The in-focus image has the highest central energy, so it sets the scale.
    # In practice the intensities at different focal planes cannot be
    # normalized by one common value because the exposure times differ!

    # magnitudes = np.floor(magnitudes / np.max(magnitudes) * np.power(2, 14))   # this floor inevitably creates some zeros
    # magnitudes1 = np.floor(magnitudes1 / np.max(magnitudes1) * np.power(2, 14))
    # magnitudes2 = np.floor(magnitudes2 / np.max(magnitudes2) * np.power(2, 14))
    # magnitudes3 = np.floor(magnitudes3 / np.max(magnitudes3) * np.power(2, 14))

    # Overexposure
    bit_camera = 14
    exposure_time = 10
    magnitudes = np.mod(
        np.floor(magnitudes / np.max(magnitudes) * np.power(2, bit_camera) *
                 exposure_time),
        np.power(2, bit_camera))  # this floor inevitably creates some zeros
    magnitudes1 = np.mod(
        np.floor(magnitudes1 / np.max(magnitudes1) * np.power(2, bit_camera) *
                 exposure_time), np.power(2, bit_camera))
    magnitudes2 = np.mod(
        np.floor(magnitudes2 / np.max(magnitudes2) * np.power(2, bit_camera) *
                 exposure_time), np.power(2, bit_camera))
    magnitudes3 = np.mod(
        np.floor(magnitudes3 / np.max(magnitudes3) * np.power(2, bit_camera) *
                 exposure_time), np.power(2, bit_camera))

    # In practice exact normalization is impossible; the image overexposes first.
    # normalize_factor = np.max(magnitudes)
    # magnitudes = np.floor(magnitudes / normalize_factor * np.power(2, 14))
    # magnitudes1 = np.floor(magnitudes1 / normalize_factor * np.power(2, 14))
    # magnitudes2 = np.floor(magnitudes2 / normalize_factor * np.power(2, 14))
    # magnitudes3 = np.floor(magnitudes3 / normalize_factor * np.power(2, 14))

    # In practice the absolute intensity is unknown, so compute the loss with
    # the normalized intensity.
    magnitudes /= np.power(2, bit_camera)
    magnitudes1 /= np.power(2, bit_camera)
    magnitudes2 /= np.power(2, bit_camera)
    magnitudes3 /= np.power(2, bit_camera)

    # Convert back from intensities to amplitudes
    magnitudes = np.sqrt(magnitudes)
    magnitudes1 = np.sqrt(magnitudes1)
    magnitudes2 = np.sqrt(magnitudes2)
    magnitudes3 = np.sqrt(magnitudes3)
    # ----------------------

    # Save the intensities
    # np.save('intensity1.npy', magnitudes**2)
    # np.save('intensity2.npy', magnitudes1**2)
    # np.save('intensity3.npy', magnitudes2**2)

    # plt.subplot(121)
    # plt.imshow(magnitudes1, cmap='gray')
    # plt.title('magnitudes1')
    # plt.subplot(122)
    # plt.imshow(magnitudes2, cmap='gray')
    # plt.title('magnitudes2')
    # plt.show()
    # plt.figure()
    # plt.plot(magnitudes[0, :], color='r')
    # plt.plot(magnitudes1[0, :], color='g')
    # plt.plot(magnitudes2[0, :], color='b')
    # plt.show()

    # magnitudes_t = torch.Tensor(magnitudes)
    # Amplitudes (not intensities)
    magnitudes_t = Variable(data_transform(magnitudes).unsqueeze(0))
    magnitudes_t1 = Variable(data_transform(magnitudes1).unsqueeze(0))
    magnitudes_t2 = Variable(data_transform(magnitudes2).unsqueeze(0))
    magnitudes_t3 = Variable(data_transform(magnitudes3).unsqueeze(0))

    # Optionally add Gaussian random noise to the measurements:
    # noise = magnitudes_t.clone()
    # noise1 = magnitudes_t1.clone()
    # noise2 = magnitudes_t2.clone()
    # magnitudes_t += Variable(noise.normal_() * 10)
    # # print('psnr of the original image and noises: ',
    # #       psnr(magnitudes_t.data.numpy() / np.max(tmp), tmp / np.max(tmp)))
    # # plt.subplot(221)
    # # plt.imshow(tmp.squeeze())
    # # plt.subplot(222)
    # # plt.imshow(magnitudes_t.squeeze().numpy())
    # # plt.subplot(212)
    # # plt.plot(tmp.squeeze()[0, :], color='r')
    # # plt.plot(magnitudes_t.squeeze().numpy()[0, :], color='b')
    # # plt.show()
    # magnitudes_t1 += Variable(noise1.normal_() * 10)
    # magnitudes_t2 += Variable(noise2.normal_() * 10)

    # Multiple modulations
    mse_loss, mse_loss2, retrieved_phase = retrieve(
        magnitudes_t, magnitudes_t1, magnitudes_t2, magnitudes_t3, phase_obj,
        phase_obj1, phase_obj2, phase_obj3, phase_modulate1, phase_modulate2,
        phase_modulate3, defocus_term)
    # Single modulation
    # mse_loss, mse_loss2, retrieved_phase = retrieve(
    #     magnitudes_t, magnitudes_t1, magnitudes_t2,
    #     phase_obj, phase_obj1, None, phase_modulate1, None)
    # No modulation
    # mse_loss, mse_loss2, retrieved_phase = retrieve(
    #     magnitudes_t, magnitudes_t1, magnitudes_t2,
    #     phase_obj, None, None, None, None)

    np.save('mse_loss_2.npy', mse_loss2)

    # retrieved_phase = np.mod(retrieved_phase, 2 * np.pi)
    # retrieved_phase *= 2 * np.pi * 0.5
    # retrieved_phase *= 2 * np.pi * PHASE_RANGE
    retrieved_phase *= 1

    # print(phase_obj[150:156, 150:156])
    # print(retrieved_phase[150:156, 150:156])
    # print(phase_obj[0:6, 0:6])
    # print(retrieved_phase[0:6, 0:6])

    # Handle the global phase-shift ambiguity by subtracting the minimum
    phase_obj -= np.min(phase_obj)
    retrieved_phase -= np.min(retrieved_phase)
    print('phase rmse: ', np.sqrt(np.mean((phase_obj - retrieved_phase)**2)))

    plt.subplot(221)
    plt.imshow(phase_obj, cmap='gray')
    plt.title('phase_obj')
    plt.axis('off')
    plt.subplot(222)
    plt.imshow(retrieved_phase, cmap='gray')
    plt.title('retrieved_phase')
    plt.axis('off')
    plt.colorbar()
    plt.subplot(212)
    plt.plot(phase_obj[128, :], color='r')
    plt.plot(retrieved_phase[128, :], color='b')
    plt.show()

    real_part = torch.cos(
        torch.tensor(retrieved_phase).type(dtype)).unsqueeze(-1)
    image_part = torch.sin(
        torch.tensor(retrieved_phase).type(dtype)).unsqueeze(-1)
    complex_phase = torch.cat((real_part, image_part), dim=-1).squeeze()
    f_phase = torch.fft(complex_phase, signal_ndim=2)
    re = torch.index_select(f_phase, dim=2,
                            index=torch.tensor(0).type(torch.cuda.LongTensor))
    im = torch.index_select(f_phase, dim=2,
                            index=torch.tensor(1).type(torch.cuda.LongTensor))
    pred_intensity = torch.sqrt(re**2 + im**2).squeeze().data.cpu().numpy()

    check_magnitudes = np.abs(np.fft.fft2(np.exp(1j * retrieved_phase)))

    np.save('phase_obj', phase_obj)
    np.save('retrieved_phase', retrieved_phase)
    np.save('magnitudes', np.abs(np.fft.fft2(np.exp(1j * phase_obj))))
    np.save('check_magnitudes', check_magnitudes)

    # print(magnitudes[150:156, 150:156])
    # print(pred_intensity[150:156, 150:156])
    # print(magnitudes[0:6, 0:6])
    # print(pred_intensity[0:6, 0:6])
    print('intensity rmse: ',
          np.sqrt(np.mean((magnitudes - pred_intensity)**2)))
    print('rmse between original magnitudes and fft of pred_phase:',
          np.sqrt(np.mean((magnitudes - check_magnitudes)**2)))
    print('magnitudes shape: ', magnitudes.shape)
    print('check_magnitudes shape: ', check_magnitudes.shape)
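# Hedged sketch (illustration only): the gray-level quantization and
# overexposure model applied to the four intensities in run(), factored into
# a helper. `simulate_camera` is a hypothetical name, not from this repo.
def simulate_camera(intensity, bit_camera=14, exposure_time=10):
    # Scale to full well, overexpose, wrap to the camera bit depth, and
    # return the normalized digital intensity.
    counts = np.floor(intensity / np.max(intensity) *
                      np.power(2, bit_camera) * exposure_time)
    counts = np.mod(counts, np.power(2, bit_camera))  # overexposure wrap-around
    return counts / np.power(2, bit_camera)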
def __getitem__(self, idx):
    return self.actions[idx], data_transform(self.states[idx])
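# Hedged sketch (illustration only): how a dataset exposing the
# (action, transformed state) pairs above might be batched. `dataset` stands
# for an instance of the enclosing Dataset class; the names are hypothetical.
def _example_dataloader(dataset, batch_size=32):
    from torch.utils.data import DataLoader
    # Each batch yields (actions, states) with states already transformed.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)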
def inference_exp():
    # ---- Load the system (no-sample) result ----
    unet1 = U_Net(in_ch=1, out_ch=1).cuda()
    unet1.load_state_dict(
        torch.load('D:/00 论文相关/毕设/实验/恢复结果/0105-圆孔-2pi/best.pth'))

    # Measured Fourier intensity used to train the network
    input = np.load('D:/00 论文相关/毕设/实验/恢复结果/0105-圆孔-2pi/f.npy')
    input = input / np.max(input)
    input = np.sqrt(input)
    input = np.fft.ifftshift(input)
    magnitudes_t = Variable(data_transform(input).unsqueeze(0))

    img_size = 1024
    full_size_aperture = 34.31  # mm
    tmp = np.arange(
        -full_size_aperture / 2,
        full_size_aperture / 2 + full_size_aperture / (img_size - 1),
        full_size_aperture / (img_size - 1))
    print(tmp[-1] - tmp[0])
    x, y = np.meshgrid(tmp, tmp)
    diameter = 10
    defocus_term = (x**2 + y**2) * 2 * np.pi / (632.8e-6 * 2 * 200**2) * 50  # units: mm
    defocus_term[np.sqrt(x**2 + y**2) > diameter / 2] = 0
    defocus_term = Variable(
        data_transform(defocus_term).unsqueeze(0)).type(dtype)

    retrieved_phase_sys, out_d1, out_d2, out_d3, _, _, _ = unet1(
        magnitudes_t.type(dtype), defocus_term)
    retrieved_phase_sys = retrieved_phase_sys.data.cpu().squeeze().numpy()
    print(retrieved_phase_sys.shape)

    # Handle the phase-shift ambiguity: subtract the minimum taken inside the
    # aperture only, and leave the region outside untouched.
    min_0 = np.min(retrieved_phase_sys[np.sqrt(x**2 + y**2) < diameter / 2])
    retrieved_phase_sys[np.sqrt(x**2 + y**2) < diameter / 2] -= min_0

    # ---- Load the sample result ----
    unet2 = U_Net(in_ch=1, out_ch=1).cuda()
    unet2.load_state_dict(
        torch.load('D:/00 论文相关/毕设/实验/恢复结果/0105-zju-2pi-2/best.pth'))

    input = np.load('D:/00 论文相关/毕设/实验/恢复结果/0105-zju-2pi-2/f.npy')
    input = input / np.max(input)
    input = np.sqrt(input)
    input = np.fft.ifftshift(input)
    magnitudes_t = Variable(data_transform(input).unsqueeze(0))

    retrieved_phase, out_d1, out_d2, out_d3, _, _, _ = unet2(
        magnitudes_t.type(dtype), defocus_term)
    retrieved_phase = retrieved_phase.data.cpu().squeeze().numpy()
    print(retrieved_phase.shape)
    min_0 = np.min(retrieved_phase[np.sqrt(x**2 + y**2) < diameter / 2])
    retrieved_phase[np.sqrt(x**2 + y**2) < diameter / 2] -= min_0  # subtract inside the aperture only

    # ---- Low-pass filter the results ----
    print('max and min of sys: ', np.max(retrieved_phase_sys),
          np.min(retrieved_phase_sys))
    print('max and min of zju: ', np.max(retrieved_phase),
          np.min(retrieved_phase))
    retrieved_phase_sys = cv2.medianBlur(retrieved_phase_sys, 3)
    # retrieved_phase = cv2.medianBlur(retrieved_phase, 3)
    # Bilateral filtering
    # retrieved_phase = cv2.bilateralFilter(retrieved_phase, 5, 0.3, 1)

    # ---- Correction: subtract the system phase from the sample phase ----
    retrieved_phase_corr = retrieved_phase - retrieved_phase_sys
    retrieved_phase_corr = cv2.medianBlur(retrieved_phase_corr, 3)
    retrieved_phase_corr = cv2.bilateralFilter(retrieved_phase_corr, 5, 0.3, 1)

    # 2-D maps
    plt.figure()
    plt.imshow(retrieved_phase_sys[300:-300, 300:-300] / (2 * np.pi),
               cmap='gray')
    plt.axis('off')
    plt.title('result_sys')
    plt.colorbar()

    plt.figure()
    # plt.imshow(retrieved_phase / (2 * np.pi), cmap='gray', vmin=0, vmax=0.5)
    plt.imshow(retrieved_phase[300:-300, 300:-300] / (2 * np.pi), cmap='gray')
    # plt.colorbar(ticks=[0, 0.5])
    plt.axis('off')
    plt.title('result_tmp')
    plt.colorbar()

    plt.figure()
    # plt.imshow(retrieved_phase / (2 * np.pi), cmap='gray', vmin=0, vmax=0.5)
    plt.imshow(retrieved_phase_corr[300:-300, 300:-300] / (2 * np.pi),
               cmap='gray')
    # plt.colorbar(ticks=[0, 0.5])
    plt.plot(range(100, 300), [212] * 200, color=pcolor['red'],
             linestyle='--', linewidth=4)
    plt.axis('off')
    plt.title('result_zju')
    plt.colorbar()

    # 1-D profile
    plt.figure()
    plt.plot(range(1024), retrieved_phase_corr[512, :] / (2 * np.pi),
             color=pcolor['red'], linestyle='-', linewidth=2)  # full profile
    # plt.plot(range(100, 300), retrieved_phase_corr[512, 400:-424] / (2 * np.pi),
    #          color=pcolor['red'], linestyle='-', linewidth=2)
    plt.legend(['experimental results'], fontsize=20, loc='upper left')
    # plt.ylim([0, 0.5])
    # plt.xlim([75, 150])
    # plt.yticks([0, 0.25, 0.5])
    # plt.xticks([75, 112, 150], fontsize=20)
    # plt.yticks(fontsize=20)

    # 3-D surface
    fig = plt.figure()
    x, y = np.meshgrid(range(424), range(424))
    # Subtract the defocus surface obtained from the polynomial fit
    x1, y1 = x - 212, y - 212
    coff = [-2.224758e-04, -1.040499e-03, 2.35139]
    generate = (x1**2 + y1**2) * coff[0] + (x1 + y1) * coff[1] + coff[2]
    generate = generate / (2 * np.pi)
    # final = tmp - generate
    # final = final - np.min(final)
    # plt.imshow(final[140:270, :], cmap='gray')
    # plt.colorbar()
    # ax = plt.axes(projection='3d')
    tmp = retrieved_phase_corr[300:-300, 300:-300] / (2 * np.pi)
    # tmp = tmp - generate
    copy = tmp.copy()
    tmp[copy != 0] = tmp[copy != 0] - generate[copy != 0]

    test = tmp[142:271, 70:353]
    # test = tmp[142:271, 75:348]
    plt.figure()
    # plt.plot(test[test.shape[0] // 2, :])
    plt.imshow(test, vmin=-0.2, vmax=0.5)
    # test[test == 0] = None
    print('result pv: %f rad' % (2 * np.pi * (np.max(test) - np.min(test))))
    print('result rms: %f rad' %
          (2 * np.pi *
           np.sqrt(np.sum(test**2) / (test.shape[0] * test.shape[1]))))
    plt.colorbar(ticks=[-0.2, 0, 0.5])
    plt.show()

    plt.imshow(tmp, cmap='gray')
    plt.axis('off')
    plt.title('result')
    plt.colorbar()

    tmp[copy == 0] = None
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.plot_surface(x, y, tmp[::-1, :], rstride=10, cstride=10, cmap='jet',
                    edgecolor='none', vmin=0, vmax=0.5)
    ax.set_zticks([-0.2, 0.15, 0.5])
    ax.view_init(azim=-124, elev=83)
    ax.set_title('sample')
    # plt.contour3D(x, y, ...)

    # # Subtract the defocus surface obtained from the polynomial fit
    # x, y = x - 212, y - 212
    # coff = [-2.214758e-04, 3.140499e-03, 2.55139]
    # generate = (x**2 + y**2) * coff[0] + (x + y) * coff[1] + coff[2]
    # # plt.figure()
    # # plt.imshow(generate)
    # generate = generate / (2 * np.pi)
    # # plt.figure()
    # final = tmp - generate
    # # final = final - np.min(final)
    # plt.imshow(final[140:270, :], cmap='gray')
    # plt.colorbar()

    # Polynomial fitting
    # plt.figure()
    # plt.plot(range(img_size), retrieved_phase_corr[:, img_size // 2])
    # f1 = np.polyfit(range(-145, 150),
    #                 retrieved_phase_corr[365:660, img_size // 2], 4)
    # print(f1)
    # p1 = np.poly1d(f1)
    # fit_value = p1(range(-145, 150))
    # plt.plot(range(365, 660), fit_value)
    # f1 = np.polyfit(range(-119, 120),
    #                 retrieved_phase[391:630, img_size // 2], 2)
    # print(f1)
    # p1 = np.poly1d(f1)
    # fit_value = p1(range(-119, 120))
    # plt.plot(range(391, 630), fit_value)

    plt.show()
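# Hedged sketch (illustration only): the peak-to-valley and RMS figures
# printed in inference_exp(), factored into a helper. `phase_stats` is a
# hypothetical name; the input is a phase map in waves, the outputs are in
# radians.
def phase_stats(phase_in_waves):
    pv = 2 * np.pi * (np.max(phase_in_waves) - np.min(phase_in_waves))
    rms = 2 * np.pi * np.sqrt(np.mean(phase_in_waves**2))
    return pv, rms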
def test():
    # ---- read and process intensity images ---- #
    # bkg1 = np.load('exp/w_sample/f.npy')
    # The aperture is currently set in three places: 1) the defocus term here,
    # 2) the cyl layer in net.py, 3) calc_intensity in fit.py.
    # Likewise the size appears in three places: here, the cyl in net, and the
    # cyl in calc_intensity.
    img_size = 1024

    # Circular aperture, no sample
    # j1 = np.load('./exp/1221/wo_sample/combined/%s/f_less_exposed.npy' % img_size)
    # df1 = np.load('./exp/1221/wo_sample/combined/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1221/wo_sample/combined/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/1221/wo_sample/combined/%s/df10.npy' % img_size)

    # No sample, circular aperture, denoised
    # j1 = np.load('./exp/1221/wo_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/1221/wo_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1221/wo_sample/combined/denoised/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/1221/wo_sample/combined/denoised/%s/df10.npy' % img_size)

    # With sample, circular aperture (annular plate)
    # j1 = np.load('./exp/1221/w_sample/combined/%s/f.npy' % img_size)
    # df1 = np.load('./exp/1221/w_sample/combined/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1221/w_sample/combined/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/1221/w_sample/combined/%s/df10.npy' % img_size)

    # With sample, circular aperture, denoised, 5 mm aperture
    # j1 = np.load('./exp/1221/w_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/1221/w_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1221/w_sample/combined/denoised/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/1221/w_sample/combined/denoised/%s/df10.npy' % img_size)
    # print(np.min(j1), np.min(df1), np.min(df2), np.min(df3))

    # No sample, circular aperture, denoised, 10 mm aperture
    # j1 = np.load('./exp/1228/wo_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/1228/wo_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1228/wo_sample/combined/denoised/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/1228/wo_sample/combined/denoised/%s/df10.npy' % img_size)

    # With sample (zju), denoised, 10 mm aperture
    # j1 = np.load('./exp/1228/w_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/1228/w_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1228/w_sample/combined/denoised/%s/df7_2.npy' % img_size)
    # df3 = np.load('./exp/1228/w_sample/combined/denoised/%s/df10_2.npy' % img_size)

    # With sample (zju), denoised, 10 mm aperture, 1230 data
    # j1 = np.load('./exp/1230/w_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/1230/w_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/1230/w_sample/combined/denoised/%s/df7_2.npy' % img_size)
    # df3 = np.load('./exp/1230/w_sample/combined/denoised/%s/df10.npy' % img_size)

    # With sample (zju), denoised, 10 mm aperture, 0103 data
    # j1 = np.load('./exp/0103/w_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/0103/w_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/0103/w_sample/combined/denoised/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/0103/w_sample/combined/denoised/%s/df10.npy' % img_size)

    # With sample (zju), denoised, 10 mm aperture, 0105 data
    j1 = np.load('./exp/0105/w_sample/combined/denoised/%s/f_2.npy' % img_size)
    df1 = np.load('./exp/0105/w_sample/combined/denoised/%s/df4.npy' % img_size)
    df2 = np.load('./exp/0105/w_sample/combined/denoised/%s/df7.npy' % img_size)
    df3 = np.load('./exp/0105/w_sample/combined/denoised/%s/df10.npy' % img_size)

    # No sample, circular aperture, denoised, 10 mm aperture, 0105 data
    # j1 = np.load('./exp/0105/wo_sample/combined/denoised/%s/f.npy' % img_size)
    # df1 = np.load('./exp/0105/wo_sample/combined/denoised/%s/df4.npy' % img_size)
    # df2 = np.load('./exp/0105/wo_sample/combined/denoised/%s/df7.npy' % img_size)
    # df3 = np.load('./exp/0105/wo_sample/combined/denoised/%s/df10.npy' % img_size)

    # Try the simulated circular aperture? It recovers a good result quickly.
    # j1 = np.load('./exp/1221/wo_sample/combined/%s/test_f.npy' % img_size)
    # df1 = np.load('./exp/1221/wo_sample/combined/%s/test_df4.npy' % img_size)
    # df2 = np.load('./exp/1221/wo_sample/combined/%s/test_df7.npy' % img_size)
    # df3 = np.load('./exp/1221/wo_sample/combined/%s/test_df10.npy' % img_size)

    # Simulated circular aperture with a shifted center
    # j1 = np.load('./exp/1221/wo_sample/combined/%s/shift_test_f.npy' % img_size)
    # df1 = np.load('./exp/1221/wo_sample/combined/%s/shift_test_df4.npy' % img_size)
    # df2 = np.load('./exp/1221/wo_sample/combined/%s/shift_test_df7.npy' % img_size)
    # df3 = np.load('./exp/1221/wo_sample/combined/%s/shift_test_df10.npy' % img_size)

    # normalization
    j1 = j1 / np.max(j1)
    df1 = df1 / np.max(df1)
    df2 = df2 / np.max(df2)
    df3 = df3 / np.max(df3)

    # convert intensities to amplitudes
    j1 = np.sqrt(j1)
    df1 = np.sqrt(df1)
    df2 = np.sqrt(df2)
    df3 = np.sqrt(df3)

    plt.subplot(221)
    plt.imshow(j1)
    plt.subplot(222)
    plt.imshow(df1)
    plt.subplot(223)
    plt.imshow(df2)
    plt.subplot(224)
    plt.imshow(df3)
    plt.show()

    # Is this shift needed? (The intensity computed in fit is not shifted.)
    j1 = np.fft.ifftshift(j1)
    df1 = np.fft.ifftshift(df1)
    df2 = np.fft.ifftshift(df2)
    df3 = np.fft.ifftshift(df3)

    j1 = Variable(data_transform(j1).unsqueeze(0)).type(dtype)
    df1 = Variable(data_transform(df1).unsqueeze(0)).type(dtype)
    df2 = Variable(data_transform(df2).unsqueeze(0)).type(dtype)
    df3 = Variable(data_transform(df3).unsqueeze(0)).type(dtype)

    # ---- generate the defocus term according to the radius of aperture ---- #
    # Imaging-lens radius 1 cm? 1.5 cm? To be measured.
    # Camera pixel size 3.69 μm? To be confirmed.
    # full_size_aperture = 24000 // 3.69  # lens diameter roughly 2.4 cm by eye
    # full_size_aperture = 24
    full_size_aperture = 34.31  # 42.88  # units: mm
    # tmp = np.arange(-full_size_aperture / 2 + full_size_aperture / img_size,
    #                 full_size_aperture / 2 + full_size_aperture / img_size,
    #                 full_size_aperture / img_size)
    tmp = np.arange(
        -full_size_aperture / 2,
        full_size_aperture / 2 + full_size_aperture / (img_size - 1),
        full_size_aperture / (img_size - 1))
    print(tmp.shape)
    x, y = np.meshgrid(tmp, tmp)
    diameter = 10
    # defocus_term = 2 * (x**2 + y**2) - 1
    defocus_term = -(x**2 + y**2) * 2 * np.pi / (632.8e-6 * 2 * 200**2) * 50  # units: mm
    # defocus_term = -(x**2 + y**2) * 2 * np.pi / (632.991e-6 * 8 * 250**2 / full_size_aperture**2) * 4 / full_size_aperture**2
    defocus_term[np.sqrt(x**2 + y**2) > diameter / 2] = 0
    defocus_term = Variable(
        data_transform(defocus_term).unsqueeze(0)).type(dtype)
    # Note: the real experiment uses a circular aperture while the simulation
    # used a square one; switching to circular adds a constraint, and the fit
    # function must be modified accordingly.
    # plt.figure()
    # plt.imshow(defocus_term)
    # plt.show()

    # ---- retrieve phase ---- #
    # Checklist:
    # - Is the output phase range correct?
    # - Is the defocus parameter range correct?
    mse_loss, mse_loss2, retrieved_phase, out_d1, out_d2, out_d3 = retrieve_exp(
        j1, df1, df2, df3, defocus_term)

    retrieved_phase -= np.min(retrieved_phase)
    out_d1 -= np.min(out_d1)
    out_d2 -= np.min(out_d2)
    out_d3 -= np.min(out_d3)

    # Subtracting the global minimum should probably be replaced by subtracting
    # the minimum inside the support; otherwise the global minimum is always 0
    # (the region outside the aperture).
    min_0 = np.min(
        retrieved_phase[np.sqrt(x**2 + y**2) < diameter / 2])  # all outputs should probably use the same reference
    # min_1 = np.min(out_d1[np.sqrt(x**2 + y**2) < diameter / 2])
    # min_2 = np.min(out_d2[np.sqrt(x**2 + y**2) < diameter / 2])
    # min_3 = np.min(out_d3[np.sqrt(x**2 + y**2) < diameter / 2])
    retrieved_phase[np.sqrt(x**2 + y**2) < diameter / 2] -= min_0  # subtract inside the aperture only
    out_d1[np.sqrt(x**2 + y**2) < diameter / 2] -= min_0
    out_d2[np.sqrt(x**2 + y**2) < diameter / 2] -= min_0
    out_d3[np.sqrt(x**2 + y**2) < diameter / 2] -= min_0

    # tmp = np.arange(-full_size_aperture / 2 + full_size_aperture / img_size,
    #                 full_size_aperture / 2 + full_size_aperture / img_size,
    #                 full_size_aperture / img_size)
    x, y = np.meshgrid(tmp, tmp)
    aperture = np.zeros((img_size, img_size))
    aperture[np.sqrt(x**2 + y**2) <= diameter / 2] = 1

    retrieved_intensity = np.square(
        np.abs(
            np.fft.fftshift(
                np.fft.fft2(aperture * np.exp(1j * retrieved_phase)))))
    retrieved_df1 = np.square(
        np.abs(np.fft.fftshift(np.fft.fft2(aperture * np.exp(1j * out_d1)))))
    retrieved_df2 = np.square(
        np.abs(np.fft.fftshift(np.fft.fft2(aperture * np.exp(1j * out_d2)))))
    retrieved_df3 = np.square(
        np.abs(np.fft.fftshift(np.fft.fft2(aperture * np.exp(1j * out_d3)))))

    plt.figure()
    plt.imshow(retrieved_phase, cmap='jet')
    plt.title('retrieved_phase')
    plt.axis('off')
    plt.colorbar()

    retrieved_intensity = retrieved_intensity / np.max(retrieved_intensity)
    retrieved_df1 = retrieved_df1 / np.max(retrieved_df1)
    retrieved_df2 = retrieved_df2 / np.max(retrieved_df2)
    retrieved_df3 = retrieved_df3 / np.max(retrieved_df3)

    plt.figure()
    plt.subplot(241)
    j1 = j1.cpu().squeeze()
    j1 = np.fft.fftshift(j1)
    j1 = np.square(j1)
    plt.imshow(j1, cmap='jet')
    plt.title('original_intensity')
    plt.axis('off')
    plt.colorbar()
    plt.subplot(245)
    plt.imshow(retrieved_intensity, cmap='jet')
    plt.title('retrieved_intensity')
    plt.axis('off')
    plt.colorbar()

    plt.subplot(242)
    df1 = df1.cpu().squeeze()
    df1 = np.fft.fftshift(df1)
    df1 = np.square(df1)
    plt.imshow(df1, cmap='jet')
    plt.title('original_intensity_df1')
    plt.axis('off')
    plt.colorbar()
    plt.subplot(246)
    plt.imshow(retrieved_df1, cmap='jet')
    plt.title('retrieved_df1')
    plt.axis('off')
    plt.colorbar()

    plt.subplot(243)
    df2 = df2.cpu().squeeze()
    df2 = np.fft.fftshift(df2)
    df2 = np.square(df2)
    plt.imshow(df2, cmap='jet')
    plt.title('original_intensity_df2')
    plt.axis('off')
    plt.colorbar()
    plt.subplot(247)
    plt.imshow(retrieved_df2, cmap='jet')
    plt.title('retrieved_df2')
    plt.axis('off')
    plt.colorbar()

    plt.subplot(244)
    df3 = df3.cpu().squeeze()
    df3 = np.fft.fftshift(df3)
    df3 = np.square(df3)
    plt.imshow(df3, cmap='jet')
    plt.title('original_intensity_df3')
    plt.axis('off')
    plt.colorbar()
    plt.subplot(248)
    plt.imshow(retrieved_df3, cmap='jet')
    plt.title('retrieved_df3')
    plt.axis('off')
    plt.colorbar()

    plt.figure()
    plt.plot(range(img_size), retrieved_intensity[:, img_size // 2], color='b')
    plt.plot(range(img_size), j1[:, img_size // 2], color='r')
    plt.figure()
    plt.plot(range(img_size), retrieved_df1[:, img_size // 2], color='b')
    plt.plot(range(img_size), df1[:, img_size // 2], color='r')
    plt.figure()
    plt.plot(range(img_size), retrieved_df2[:, img_size // 2], color='b')
    plt.plot(range(img_size), df2[:, img_size // 2], color='r')
    plt.figure()
    plt.plot(range(img_size), retrieved_df3[:, img_size // 2], color='b')
    plt.plot(range(img_size), df3[:, img_size // 2], color='r')

    plt.figure()
    plt.plot(range(len(mse_loss)), mse_loss)
    plt.figure()
    plt.plot(range(0, len(mse_loss), 10), mse_loss2)
    plt.show()
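# Hedged sketch (illustration only): the quadratic defocus phase term built in
# both test() and inference_exp(), factored into a helper. The function name
# and argument names are hypothetical; the default values simply mirror the
# constants used above.
def make_defocus_term(img_size=1024, full_size_aperture=34.31, diameter=10,
                      wavelength=632.8e-6, focal=200, scale=50):
    tmp = np.arange(
        -full_size_aperture / 2,
        full_size_aperture / 2 + full_size_aperture / (img_size - 1),
        full_size_aperture / (img_size - 1))
    x, y = np.meshgrid(tmp, tmp)
    # Quadratic (defocus) phase over a circular aperture, zero outside.
    defocus = -(x**2 + y**2) * 2 * np.pi / (wavelength * 2 * focal**2) * scale
    defocus[np.sqrt(x**2 + y**2) > diameter / 2] = 0
    return defocus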