def predict_movies(d):
    """Predict on four C. elegans time points and save the predictions as float16 tifs.

    Despite the historical name, movie export is disabled (``make_movies`` below);
    the only files written are ``pimgs/pimg01_<ttt>.tif`` under ``d.savedir``.

    Parameters
    ----------
    d : SimpleNamespace-like
        Must provide ``d.net`` (passed to ``apply_net_tiled``) and ``d.savedir``;
        ``d.savedir`` is replaced in place by a ``Path``.
    """
    d.savedir = Path(d.savedir)
    for subdir in ('pimgs/', 'movie/'):
        (d.savedir / subdir).mkdir(exist_ok=True)

    ds = "01"
    pmin, pmax = 2, 99.6
    make_movies = False  ## flip to also write mp4 movies (needs i2rgb / moviesave)
    for i in [0, 10, 100, 189]:
        raw = imread(f'/lustre/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/{ds}/t{i:03d}.tif')
        raw = normalize3(raw, pmin, pmax).astype(np.float32, copy=False)
        with torch.no_grad():
            pred = apply_net_tiled(d.net, raw[None])
        ## raw | prediction side by side; only meant for the (disabled) movie export
        side_by_side = cat(raw, pred[0], axis=1).clip(min=0)
        imsave(pred.astype(np.float16), d.savedir / f'pimgs/pimg{ds}_{i:03d}.tif', compress=9)
        if make_movies:
            mix = i2rgb(raw)
            mix[..., [0, 2]] = pred[0, ..., None][..., [0, 0]]
            mix[..., 1] -= pred[0]
            mix = mix.clip(min=0)
            moviesave(normalize3(pred[0]), d.savedir / f'movie/pimg{i:03d}.mp4', rate=4)
            ## set i=30 and i=150 to get res022 and res023.
            moviesave(normalize3(mix), d.savedir / f'movie/mix{i:03d}.mp4', rate=4)
def datagen(savedir=None):
    """Cut the first 10 frames of ``pred_flower.tif`` into 256x256 training tiles.

    Frames are normalized with fixed percentiles (2, 99.6) and split into a
    4x4 grid of 256x256 tiles each (presumably 1024x1024 frames).

    Parameters
    ----------
    savedir : Path or None
        If truthy, a preview PNG and ``data_flower.npz`` (keys ``data``, ``pmin``,
        ``pmax``) are written there.

    Returns
    -------
    SimpleNamespace with ``data`` (tiles, shape (-1, 1, 256, 256)), ``pmin`` and ``pmax``.
    """
    frames = imread(
        f'/lustre/projects/project-broaddus/denoise_experiments/flower/e02/pred_flower.tif'
    )[:10]
    print(frames.shape)

    pmin, pmax = 2, 99.6
    print(f"pmin = {pmin}; pmax = {pmax}")
    frames = normalize3(frames, pmin, pmax).astype(np.float32, copy=False)

    ## (frame, tile_row, y, tile_col, x) -> (frame, tile_row, tile_col, y, x) -> tiles
    grid = frames.reshape((-1, 4, 256, 4, 256))
    grid = grid.transpose((0, 1, 3, 2, 4))
    data = grid.reshape((-1, 1, 256, 256))
    print("data.shape: ", data.shape)

    if savedir:
        preview = collapse2(data[:, :], 'scyx', 's,y,x,c')[..., [0, 0, 0]]
        preview = plotgrid([normalize3(preview)], 10)
        io.imsave(savedir / 'data_xy_flower.png', preview)
        np.savez_compressed(savedir / 'data_flower.npz', data=data, pmin=pmin, pmax=pmax)

    return SimpleNamespace(data=data, pmin=pmin, pmax=pmax)
def receptivefield(net):
    """Visualise the receptive field of a 3D network.

    Every ``nn.Conv3d`` in ``net`` is overwritten in place with a constant kernel
    of weight 1/75 and zero bias. A unit impulse at the centre of a 128^3 volume
    is then pushed through the network on the GPU, and the normalized central
    xy and xz slices of the response are saved as PNGs.

    NOTE: ``savedir`` is not a parameter; it is read from module scope.
    """
    def _constant_kernel(module):
        if type(module) == nn.Conv3d:
            module.weight.data.fill_(1 / 75)  ## 1/75 = 1/(3*5*5): assumes 3x5x5 kernels
            module.bias.data.fill_(0.0)

    net.apply(_constant_kernel)

    impulse = np.zeros((128, 128, 128))
    impulse[64, 64, 64] = 1
    model = net.cuda()
    response = model(torch.from_numpy(impulse)[None, None].float().cuda()).detach().cpu().numpy()

    io.imsave(savedir / 'recfield_xy.png', normalize3(response[0, 0, 64]))
    io.imsave(savedir / 'recfield_xz.png', normalize3(response[0, 0, :, 64]))
def fig2_shutterclosed_comparison():
    """Assemble the figure-2 comparison array from the epoch-80 snapshots of runs cele1..cele5.

    Each ``arr_080.npy`` is assumed to be laid out as (type, sample, channel, y, x),
    with types (raw, mask, raw fft, pred, pred fft). The runs are stacked along a new
    leading model axis. FFT types are log-compressed, and each group (fft,
    real-space, mask) is normalized separately.

    The result has a first row (blank, raw, raw fft) taken from run 0, followed by
    (mask, pred, pred fft) for every run. It is saved to ``savedir.parent / 'fig2_cele.npz'``
    and returned.

    NOTE: ``savedir`` is not a parameter; it is read from module scope.
    """
    root = '/lustre/projects/project-broaddus/devseg_data/cl_datagen/grid'
    runs = [np.load(f'{root}/cele{k}/epochs_npy/arr_080.npy') for k in range(1, 6)]
    ## FIX: the loaded runs were never stacked, so `rgb` below was undefined
    rgb = stak(*runs)

    ## normalize fft and real space separately (dims: model, type, sample, channel, y, x)
    rgb[:, [2, 4]] = normalize3(np.log(normalize3(rgb[:, [2, 4]], 0, 99) + 1e-7))
    rgb[:, [0, 3]] = normalize3(rgb[:, [0, 3]])
    rgb[:, 1] = normalize3(rgb[:, 1])

    ## drop the channel axis -> model, type, sample, y, x
    rgb = rgb[:, :, :, 0]

    ## first row: (blank, raw, raw fft) of run 0; then (mask, pred, pred fft) per run
    rgb = cat(
        stak(np.zeros(rgb[0, 0].shape), rgb[0, 0], rgb[0, 2])[None],
        rgb[:, [1, 3, 4]])

    np.savez_compressed(savedir.parent / 'fig2_cele.npz', rgb=rgb)
    return rgb
def datagen(params=None, savedir=None):
    """Sample random 3D training patches from all 190 time points of C. elegans movie 01.

    Each frame is normalized with randomly jittered percentiles
    (pmin ~ U(1, 3), pmax ~ U(99.5, 99.8)). Then 10 patches of shape (32, 64, 64)
    are drawn per time point; patches whose mean is below 0.03 are rejected and redrawn.

    Parameters
    ----------
    params : unused
        Kept for backward compatibility. The default changed from a shared ``{}`` to ``None``.
    savedir : Path or None
        If truthy, xy / xz preview PNGs are written there. So are the patch array
        (``data_cele.npz``, default key ``arr_0``) and the sampled slices
        (``slicelist_cele.pkl``, ordered time-major, then by sample).

    Returns
    -------
    np.ndarray
        Patches stacked as (time, sample, channel, z, y, x) with a singleton channel.
    """
    patch_shape = (32, 64, 64)
    min_mean = 0.03  ## reject mostly-background patches; threshold chosen by manual inspection
    data = []
    ## FIX: previously reset inside the time loop, so only the last time point's
    ## slices were pickled; now one entry per patch, in the same order as `data`
    slicelist = []

    def random_patch(img):
        "draw one patch whose mean exceeds `min_mean` and record its slice"
        ss = random_slice(img.shape, patch_shape)
        while img[ss].mean() < min_mean:
            ss = random_slice(img.shape, patch_shape)
        patch = img[ss].copy()
        slicelist.append(ss)
        return (patch, )

    for i in np.r_[:190]:
        img = imread(
            f'/lustre/projects/project-broaddus/devseg_data/raw/celegans_isbi/Fluo-N3DH-CE/01/t{i:03d}.tif'
        )
        pmin, pmax = np.random.uniform(1, 3), np.random.uniform(99.5, 99.8)
        img = normalize3(img, pmin, pmax).astype(np.float32, copy=False)
        data.append([random_patch(img) for _ in range(10)])

    data = np.array(data)  ## t, s, c, z, y, x
    print("data.shape: ", data.shape)

    if savedir:
        rgb = collapse2(data[:, :, :, 16], 'tscyx', 'ty,sx,c')[..., [0, 0, 0]]
        rgb = normalize3(rgb)
        io.imsave(savedir / 'data_xy_cele.png', rgb)
        rgb = collapse2(data[:, :, :, :, 32], 'tsczx', 'tz,sx,c')[..., [0, 0, 0]]
        rgb = normalize3(rgb)
        io.imsave(savedir / 'data_xz_cele.png', rgb)
        np.savez_compressed(savedir / 'data_cele.npz', data)
        pklsave(slicelist, savedir / 'slicelist_cele.pkl')
    return data
def setup_flower_shutter(
        rawdata='/lustre/projects/project-broaddus/rawdata/artifacts/flower.tif',
        savedir='/lustre/projects/project-broaddus/denoise_experiments/flower/e01/flower_test'
):
    """Prepare a fresh experiment directory, a network and 256x256 training tiles.

    ``savedir`` is wiped and re-initialised. The raw stack is normalized with
    percentiles (2, 99.6) and cut into 4x4 grids of 256x256 tiles per frame.

    Returns
    -------
    SimpleNamespace with ``net`` (Unet2_2d), ``savedir`` (Path) and ``x1_all``
    (float tensor of tiles, shape (-1, 1, 256, 256)).
    """
    savedir = Path(savedir)
    wipe_dirs(savedir)
    init_dirs(savedir)

    frames = imread(rawdata)
    pmin, pmax = 2, 99.6
    print(f"pmin = {pmin}; pmax = {pmax}")
    frames = normalize3(frames, pmin, pmax).astype(np.float32, copy=False)

    ## (frame, tile_row, y, tile_col, x) -> (frame, tile_row, tile_col, y, x) -> tiles
    tiles = frames.reshape((-1, 4, 256, 4, 256)).transpose((0, 1, 3, 2, 4))
    tiles = tiles.reshape((-1, 1, 256, 256))

    net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)
    net.load_state_dict(
        torch.load(
            '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit.pt'
        ))
    ## NOTE(review): init_weights is applied right after loading, possibly overriding the loaded weights — confirm intended
    net.apply(init_weights)

    return SimpleNamespace(net=net, savedir=savedir, x1_all=torch.from_numpy(tiles).float())
def setup_flower_shutter(rawdata, savedir):
    """Prepare a supervised experiment: inputs are the 100 frames, the target is their mean.

    ``savedir`` is wiped and re-initialised. The stack is normalized with percentiles
    (2, 99.6) and reshaped into xs of shape (100, 4, 4, 1, 256, 256), i.e. (frame,
    tile_row, tile_col, channel, y, x). ys is xs averaged over the frame axis.
    Previews ``xs.png`` (first frame) and ``ys.png`` are written.

    Returns
    -------
    SimpleNamespace with ``net``, ``savedir``, ``xs``, ``ys`` and ``cuda=False``.
    """
    savedir = Path(savedir)
    wipe_dirs(savedir)
    init_dirs(savedir)

    ## NOTE(review): normalization runs over the whole stack, not per frame
    stack = normalize3(imread(rawdata), 2, 99.6)

    net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)
    net.load_state_dict(
        torch.load(
            '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit.pt'
        ))

    xs = torch.from_numpy(stack).float()
    xs = xs.reshape(100, 4, 256, 4, 256, 1).permute(0, 1, 3, 5, 2, 4)
    ys = xs.mean(0)

    io.imsave(savedir / 'xs.png', collapse2(xs[0, :, :, 0].numpy(), "12yx", "1y,2x"))
    io.imsave(savedir / 'ys.png', collapse2(ys[:, :, 0].numpy(), "12yx", "1y,2x"))

    return SimpleNamespace(net=net, savedir=savedir, xs=xs, ys=ys, cuda=False)
def load_cele():
    """Load raw and denoised C. elegans frames for time points 0, 10, 100 and 189.

    Returns
    -------
    SimpleNamespace(raw, n2v2, nlm, n2v), each stacked along a new leading time axis.
    ``raw`` is normalized with percentiles (2, 99.6). For ``n2v``/``n2v2``, index 0
    of axis 1 is taken (presumably the singleton channel of the saved predictions).
    """
    times = [0, 10, 100, 189]

    def _stack(template):
        "read one file per time point (``{t:03d}`` placeholder) and stack them"
        return np.array([imread(template.format(t=t)) for t in times])

    raw = _stack("/lustre/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/01/t{t:03d}.tif")
    raw = normalize3(raw, 2, 99.6)
    n2v = _stack("/projects/project-broaddus/denoise_experiments/cele/e01/cele1/pimgs/pimg01_{t:03d}.tif")[:, 0]
    n2v2 = _stack("/projects/project-broaddus/denoise_experiments/cele/e01/cele3/pimgs/pimg01_{t:03d}.tif")[:, 0]
    nlm = _stack("/projects/project-broaddus/denoise_experiments/cele/e01/nlm/denoised{t:03d}.tif")
    return SimpleNamespace(raw=raw, n2v2=n2v2, nlm=nlm, n2v=n2v)
def predict_full():
    """Denoise every frame of the raw flower stack and save the predictions as one tif.

    Loads the epoch-600 weights of run ``flower3_6``. The stack is normalized
    with percentiles (2, 99.6) computed over axes (1, 2), presumably per frame.
    The network is applied tile-wise to each frame, and the result is written
    to ``pred_flower.tif``.

    NOTE: the output directory ``savedir`` is not a parameter; it is read from
    module scope and must be defined before calling.
    """
    net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU).cuda()
    net.load_state_dict(
        torch.load(
            '/lustre/projects/project-broaddus/denoise/flower/e01/flower3_6/models/net600.pt'
        ))
    img = imread('/lustre/projects/project-broaddus/devseg_data/raw/artifacts/flower.tif')
    pmin, pmax = 2, 99.6
    img = normalize3(img, pmin, pmax, axs=(1, 2)).astype(np.float32, copy=False)
    ## one frame at a time, with a leading channel axis added for the net
    pimg = np.array([apply_net_tiled(net, x[None]) for x in img])
    imsave(pimg, savedir / 'pred_flower.tif')
def make_visual_table_cele(dat, outfile=None):
    """Save a side-by-side 256x256 crop (z-slice 14) of RAW | NLM | N2V | N2V2 for the first time point.

    ``dat`` is expected to look like the result of ``load_cele``. ``outfile`` must
    be given despite its ``None`` default, since it is passed straight to ``io.imsave``.
    """
    ## column order: RAW NLM N2V N2V2
    panels = normalize3(stak(dat.raw[0], dat.nlm[0], dat.n2v[0], dat.n2v2[0]))
    z, y0, x0 = 14, 256, 256  ## z-slice and top-left pixel of the crop
    crop = panels[:, z, y0:y0 + 256, x0:x0 + 256]
    ## (panel, y, x) -> (y, panel, x) -> one 256-row strip with panels side by side
    strip = crop.transpose((1, 0, 2)).reshape((256, -1))
    io.imsave(outfile, strip)
def make_visual_table(dat, outfile=None):
    """Save a side-by-side 256x256 crop comparing all methods on the first frame.

    Column labels (as given in the original label list): RAW NLM BM3D N2V N2V2 N2GT GT.
    "N2V" is ``dat.e01.data[0, 0]`` and "N2V2" is ``dat.e01.data[7, 0]``.
    ``dat`` is expected to look like the result of ``load_flower``/``load_shutter``;
    ``outfile`` must be given despite its ``None`` default.
    """
    panels = stak(dat.all[0], dat.nlm[0], dat.bm3d[0], dat.e01.data[0, 0],
                  dat.e01.data[7, 0], dat.n2gt[0], dat.gt)
    panels = normalize3(panels)
    y0, x0 = 256, 256  ## top-left pixel of the crop
    crop = panels[:, y0:y0 + 256, x0:x0 + 256]
    strip = crop.transpose((1, 0, 2)).reshape((256, -1))
    print(strip.shape)
    io.imsave(outfile, strip)
def e02_fig2_flower():
    """Assemble the figure-2 array from the epoch-400 snapshots of runs flower1_1 ... flower1_6.

    Snapshots are stacked along a new leading run axis. FFT types (2, 4) are
    log-compressed, and fft, real-space (0, 3) and mask (1) types are normalized
    separately. After dropping the channel axis, the first row is
    (blank, raw, raw fft) from run 0, followed by (mask, pred, pred fft) per run.
    The result is saved as ``e02_fig2_flower.npz`` (key ``rgb``) and returned.
    """
    root = '/lustre/projects/project-broaddus/denoise/flower/e02'
    snapshots = [np.load(f'{root}/flower1_{k}/epochs_npy/arr_400.npy') for k in range(1, 7)]
    rgb = stak(*snapshots)

    fft_types, real_types = [2, 4], [0, 3]
    rgb[:, fft_types] = normalize3(np.log(normalize3(rgb[:, fft_types], 0, 99) + 1e-7))
    rgb[:, real_types] = normalize3(rgb[:, real_types])
    rgb[:, 1] = normalize3(rgb[:, 1])

    rgb = rgb[:, :, :, 0]  ## drop the channel axis

    header_row = stak(np.zeros(rgb[0, 0].shape), rgb[0, 0], rgb[0, 2])[None]
    rgb = cat(header_row, rgb[:, [1, 3, 4]])

    np.savez_compressed(f'{root}/e02_fig2_flower.npz', rgb=rgb)
    return rgb
def optimize_nlm():
    """Find the 2D NLM sigma that minimizes MSE of denoised frame 1 vs. the mean-of-stack GT.

    Reads the module-level ``flowerdata`` path and prints the ``minimize_scalar`` result.
    """
    stack = normalize3(imread(flowerdata), 2, 99.6)
    target = stack.mean(0)

    def mse_for_sigma(sigma):
        denoised = gputools.denoise.nlm2(stack[1], sigma)
        return ((target - denoised)**2).mean()

    print(minimize_scalar(mse_for_sigma, bracket=(0.1, 0.5, 0.9)))
def predict_on_full_2d_stack(rawdata, savedir, weights):
    """Apply a trained 2D Unet frame by frame to a stack and save ``pred.tif``.

    Parameters
    ----------
    rawdata : path of the raw tif stack (normalized here with percentiles 2 / 99.6 over axes (1, 2)).
    savedir : output directory. Created if missing, like ``nlm_2d`` / ``bm3d_2d`` do.
    weights : path of a ``state_dict`` for ``Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)``.

    Writes ``savedir / 'pred.tif'`` as float16 with compression level 9.
    """
    savedir = Path(savedir)
    ## FIX: the directory was never created, so imsave failed for a fresh savedir
    savedir.mkdir(exist_ok=True, parents=True)

    net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU).cuda()
    net.load_state_dict(torch.load(weights))

    img = imread(rawdata)
    pmin, pmax = 2, 99.6
    img = normalize3(img, pmin, pmax, axs=(1, 2)).astype(np.float32, copy=False)

    ## inference only: no autograd bookkeeping needed (matches predict_movies)
    with torch.no_grad():
        pimg = np.array([apply_net_tiled_2d(net, x[None]) for x in img])
    imsave(pimg.astype(np.float16), savedir / 'pred.tif', compress=9)
def optimize_bm3d():
    """Find the BM3D sigma that minimizes MSE of denoised frame 0 vs. the mean-of-stack GT.

    Frame 0 is written once to ``<flowerdir>bm3d/eg0.tif``. For each trial sigma, the
    external bm3d binary writes ``<flowerdir>bm3d/res0.tif``, which is read back and
    scored. Relies on the module-level ``flowerdata`` and ``flowerdir`` (a string).
    """
    stack = normalize3(imread(flowerdata), 2, 99.6)
    target = stack.mean(0)

    bm3d = "/projects/project-broaddus/comparison_methods/bm3d/build/bm3d"
    noisy_file = flowerdir + 'bm3d/eg0.tif'
    imsave(stack[0], noisy_file)
    result_file = flowerdir + 'bm3d/res0.tif'

    def mse_for_sigma(sigma):
        run(f"{bm3d} {noisy_file} {sigma} {result_file}", shell=True)
        denoised = imread(result_file)
        return ((target - denoised)**2).mean()

    print(minimize_scalar(mse_for_sigma, bracket=(0.01, 0.3, 0.5)))
def nlm_3d_cele(savedir, sigma=0.1, **kwargs):
    """Denoise C. elegans frames t000, t010, t100 and t189 with 3D NLM.

    Each volume is normalized with percentiles (2, 99.6) over axes (1, 2). It is then
    passed to ``gputools.denoise.nlm3(img, sigma, **kwargs)`` and saved as
    ``savedir / 'denoised<ttt>.tif'``. ``savedir`` is created if missing.
    (Original note: for noise level sigma_0, choose sigma = 1.5*sigma_0.)
    """
    celedata = '/lustre/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/01/'
    savedir = Path(savedir)
    savedir.mkdir(exist_ok=True, parents=True)

    pmin, pmax = 2, 99.6
    for t in [0, 10, 100, 189]:
        volume = imread(celedata + f"t{t:03d}.tif")
        volume = normalize3(volume, pmin, pmax, axs=(1, 2)).astype(np.float32, copy=False)
        denoised = gputools.denoise.nlm3(volume, sigma, **kwargs)
        imsave(denoised, savedir / f'denoised{t:03d}.tif')
def receptivefield2d(net, kern=(5, 5)):
    """Visualise the receptive field of a 2D network.

    Every ``nn.Conv2d`` in ``net`` is overwritten in place with a constant kernel
    of weight ``1 / prod(kern)`` and zero bias. A unit impulse at the centre of a
    256x256 image is pushed through the network on the GPU, and the normalized
    response is saved as ``recfield_xy.png``.

    NOTE: ``savedir`` is not a parameter; it is read from module scope.
    """
    fill_value = 1 / np.prod(kern)

    def _constant_kernel(module):
        if type(module) == nn.Conv2d:
            module.weight.data.fill_(fill_value)
            module.bias.data.fill_(0.0)

    net.apply(_constant_kernel)

    impulse = np.zeros((256, 256))
    impulse[128, 128] = 1
    model = net.cuda()
    response = model(torch.from_numpy(impulse)[None, None].float().cuda()).detach().cpu().numpy()
    io.imsave(savedir / 'recfield_xy.png', normalize3(response[0, 0]))
def nlm_2d_cele_just189():
    """Grid-search 2D NLM parameters on z-slice 22 of C. elegans frame t189.

    Every (sigma, size_filter, size_search) combination is saved as
    ``nlm2_2d/t<count>.tif``, with count running in ``itertools.product`` order.
    """
    name = '/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/01/t189.tif'
    img = normalize3(imread(name), 2, 99.6)[22]

    sigmas = np.linspace(.3, .6, 10) + 0.36
    filter_sizes = [3, 4, 5]
    search_sizes = [6, 7, 8, 9, 10]
    grid = itertools.product(sigmas, filter_sizes, search_sizes)
    for count, (sigma, size_filter, size_search) in enumerate(grid):
        print(count)
        denoised = gputools.denoise.nlm2(img, sigma, size_filter=size_filter, size_search=size_search)
        outname = f"/projects/project-broaddus/denoise_experiments/cele/e01/nlm2_2d/t{count:03d}.tif"
        save(denoised, outname)
def bm3d_2d(rawdata, savedir, **kwargs):
    """Denoise the first 100 frames of a stack with the external bm3d binary.

    Each normalized frame is cached as ``savedir/tmp/img<iii>.tif`` (only written if
    missing). bm3d then writes ``savedir/img<iii>.tif`` using the sigma optimized
    against GT. ``kwargs`` is unused.
    """
    frames = normalize3(imread(rawdata), 2, 99.6)
    bm3d = "/projects/project-broaddus/comparison_methods/bm3d/build/bm3d"

    savedir = Path(savedir)
    savedir.mkdir(exist_ok=True, parents=True)
    tmpdir = savedir / 'tmp/'
    tmpdir.mkdir(exist_ok=True, parents=True)

    sigma = 0.15488  ## optimized vs GT
    for idx in range(100):
        src = tmpdir / f"img{idx:03d}.tif"
        dst = savedir / f"img{idx:03d}.tif"
        if not src.exists():
            imsave(frames[idx], src)
        run(f"{bm3d} {src} {sigma} {dst}", shell=True)
def nlm_2d(rawdata, savedir, **kwargs):
    """Denoise every frame of a 2D stack with gputools NLM and save ``denoised.tif``.

    Frames are normalized with percentiles (2, 99.6) over axes (1, 2). ``kwargs``
    is forwarded to ``gputools.denoise.nlm2``; ``savedir`` is created if missing.
    """
    savedir = Path(savedir)
    savedir.mkdir(exist_ok=True, parents=True)

    pmin, pmax = 2, 99.6
    frames = imread(rawdata)
    frames = normalize3(frames, pmin, pmax, axs=(1, 2)).astype(np.float32, copy=False)

    ## original note: for noise level of sigma_0, choose sigma = 1.5*sigma_0
    sigma = 0.1826499502297115  ## obtained through optimization vs GT
    denoised = np.array([gputools.denoise.nlm2(frame, sigma, **kwargs) for frame in frames])
    imsave(denoised, savedir / f'denoised.tif')
def nlmeval(nlm_vals, outfile):
    """Score NLM results for several parameter folders against the flower mean-of-stack GT.

    For each ``n`` in ``nlm_vals``, ``nlm/<nnnn>/denoised.tif`` is loaded and scored
    with ``eval_single(gt, result, n)``. The rows are written as CSV with header
    ``name, mse, psnr, ssim``.
    """
    flower_all = normalize3(imread(rawdata_dir / 'artifacts/flower.tif'), 2, 99.6)
    flower_gt = flower_all.mean(0)

    results = np.array([
        imread(f"/projects/project-broaddus/denoise_experiments/flower/e01/nlm/{n:04d}/denoised.tif")
        for n in nlm_vals
    ])
    rows = [eval_single(flower_gt, result, val) for result, val in zip(results, nlm_vals)]

    header = ['name', 'mse', 'psnr', 'ssim']
    with open(outfile, "w", newline="\n") as f:
        csv.writer(f).writerows([header] + rows)
def bm4d():
    """Run the MATLAB BM4D denoiser on C. elegans frame t189 for several noise levels.

    The frame is normalized with percentiles (2, 99.6), cast to float32 and written to
    ``t189_f32.tif``. For every sigma, MATLAB is then started from the bm4d directory
    with ``denoise_file('<t189_f32.tif>', sigma)``. Output is whatever ``denoise_file``
    writes; this function has no return value.
    """
    name = '/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/01/t189.tif'
    img = load(name)
    img = normalize3(img, 2, 99.6)
    img = img.astype(np.float32)  ## original note: fail: np.float64,np.float32,uint32,uint16,uint8
    name2 = '/projects/project-broaddus/denoise_experiments/data/celegans_isbi/Fluo-N3DH-CE/01/t189_f32.tif'
    imsave(img, name2)

    bm4d_dir = '/projects/project-broaddus/denoise_experiments/data/bm4d'
    matlab = '/sw/apps/matlab/current/bin/matlab'
    ## an earlier sweep used [0.04, 0.05, 0.07, 0.1, 0.14, 0.19, 0.26, 0.35, 0.49, 0.67]
    sigmas = [0.60, 0.7, 0.8, 0.9]
    for s in sigmas:
        ## `&&`: only start MATLAB if the cd succeeded (denoise_file is presumably found via the cwd)
        call = (f'cd {bm4d_dir} && {matlab} -nosplash -nodesktop -nojvm '
                f'-r "denoise_file(\'{name2}\',{s}),quit"')
        run(call, shell=True)
def nlm_3d_cele_just189(n):
    """Grid-search 3D NLM parameters on C. elegans frame t189.

    Each (sigma, size_filter, size_search) combination saves the full result to
    ``nlm2_3d/t<count>.tif`` and its z-slice 22 to ``nlm2_3d_s22/t<count>.tif``.
    ``n`` is unused; it appears to be left over from running one grid point per job.
    """
    name = '/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/01/t189.tif'
    img = normalize3(imread(name), 2, 99.6)

    sigmas = np.linspace(.3, .6, 10)
    filter_sizes = [1, 2, 3]
    search_sizes = [1, 5, 10, 15]
    grid = itertools.product(sigmas, filter_sizes, search_sizes)
    for count, (sigma, size_filter, size_search) in enumerate(grid):
        print(count)
        denoised = gputools.denoise.nlm3(img, sigma, size_filter=size_filter, size_search=size_search)
        outname = f"/projects/project-broaddus/denoise_experiments/cele/e01/nlm2_3d/t{count:03d}.tif"
        save(denoised, outname)
        save(denoised[22], outname.replace("nlm2_3d/", "nlm2_3d_s22/"))
def load_prediction_and_eval_metrics__generic(rawdata, loaddir):
    """Score the predictions stored in ``loaddir`` against the mean-of-stack GT of ``rawdata``.

    Three file layouts are recognised, in this order: ``denoised.tif``, ``pred.tif``,
    or 100 files ``img000.tif`` ... ``img099.tif``. A singleton axis 1 is dropped.
    Metrics from ``eval_single_metrics`` are written to ``loaddir / 'table.csv'``
    under the header mse, psnr, ssim.

    Raises
    ------
    FileNotFoundError
        If none of the recognised layouts is present in ``loaddir``. Previously this
        surfaced as a NameError on the unbound ``img``.
    """
    ## resolve the prediction layout first so a bad directory fails before the raw stack is loaded
    loaddir = Path(loaddir)
    if (loaddir / 'denoised.tif').exists():
        img = imread(loaddir / 'denoised.tif')
    elif (loaddir / 'pred.tif').exists():
        img = imread(loaddir / 'pred.tif')
    elif (loaddir / 'img000.tif').exists():
        img = np.array([imread(loaddir / f'img{n:03d}.tif') for n in range(100)])
    else:
        raise FileNotFoundError(
            f"no denoised.tif, pred.tif or img000.tif found in {loaddir}")

    ## deal with singleton channels
    if img.shape[1] == 1:
        img = img[:, 0]

    raw_all = normalize3(imread(rawdata), 2, 99.6)
    gt = raw_all.mean(0)

    met = eval_single_metrics(gt, img)
    header = ['mse', 'psnr', 'ssim']
    writecsv([header, met], loaddir / 'table.csv')
def bm3d_3d_cele_just189():
    """Sweep BM3D sigmas on z-slice 22 of C. elegans frame t189.

    The normalized slice is written to a ``bm3d2/tmp`` copy of the raw path. For each
    sigma in ``0.3 + 0.6 * arange(10)``, the bm3d binary writes ``out_<count>.tif``
    next to it. Each output is read back, its NaNs are zeroed, and it is saved in place.
    """
    name = '/projects/project-broaddus/rawdata/celegans_isbi/Fluo-N3DH-CE/01/t189.tif'
    name2 = name.replace("rawdata/celegans_isbi/", "denoise_experiments/cele/e01/bm3d2/tmp/")
    img = imread(name)
    img = normalize3(img, 2, 99.6)
    img = img[22]
    imsave(img, name2)

    bm3d = "/projects/project-broaddus/comparison_methods/bm3d/build/bm3d"
    r_sigma = np.arange(10) * 0.6 + 0.3
    for count, sigma in enumerate(r_sigma):
        print(count)
        name3 = Path(name2).parent / f"out_{count}.tif"
        run(f"{bm3d} {name2} {sigma} {name3}", shell=True)

        ## clean NaNs in every output, not just the last one
        out = load(name3)
        print(out.max(), out.dtype)
        m = np.isnan(out)
        print(f'nan: {m.sum()}')
        out[m] = 0
        ## FIX: re-check after cleaning; the old code re-printed the stale mask count
        print(f'nan: {np.isnan(out).sum()}')
        save(out, name3)
def load_shutter():
    """Load the shutter-closed stack, its mean-of-stack GT, and all denoising results.

    Returns
    -------
    SimpleNamespace with
        gt   : mean over axis 0 of the normalized raw stack
        e01  : SimpleNamespace(data, names) holding the nine ``mask00`` ... ``mask08``
               predictions (channel axis dropped), labelled n2v, 1x ... 8x
        all  : normalized raw stack (percentiles 2 / 99.6)
        bm3d : BM3D results, one tif per frame, in sorted filename order
        n2gt : N2GT predictions with the singleton channel removed
        nlm  : NLM result
    """
    ## load the shutter-closed dataset and build the GT
    raw_all = normalize3(imread(rawdata_dir / 'artifacts/shutterclosed.tif'), 2, 99.6)
    raw_gt = raw_all.mean(0)

    ## predictions from the single-phase models; mask00 is N2V
    names = ["n2v"] + [f"{k}x" for k in range(1, 9)]
    preds = [imread(experiments_dir / f'shutter/e01/mask{k:02d}/pred.tif') for k in range(9)]
    data = stak(*preds)
    data = data[:, :, 0]  ## remove channels dim

    nlm = imread(
        "/projects/project-broaddus/denoise_experiments/shutter/e01/nlm/denoised.tif"
    )
    bm3d_files = sorted(
        glob("/projects/project-broaddus/denoise_experiments/shutter/e01/bm3d/*.tif"))
    bm3d = np.array([imread(path) for path in bm3d_files])
    n2gt = imread(
        "/projects/project-broaddus/denoise_experiments/shutter/e01/n2gt2/pred.tif"
    )[:, 0]  ## get rid of singleton channel

    e01 = SimpleNamespace(data=data, names=names)
    return SimpleNamespace(gt=raw_gt, e01=e01, all=raw_all, bm3d=bm3d, n2gt=n2gt, nlm=nlm)
def load_flower():
    """Load the flower stack, its mean-of-stack GT, and all denoising results.

    Returns
    -------
    SimpleNamespace with
        gt   : mean over axis 0 of the normalized raw stack
        e01  : SimpleNamespace(data, names). ``data`` stacks, for each of the nine
               masks ``mask00`` (N2V) ... ``mask08``, the predictions of 5 repeats
               (``mask<kk>_<n>``). Labels: N2V, 1x ... 8x
        all  : normalized raw stack (percentiles 2 / 99.6)
        nlm  : NLM result from parameter folder ``0010``, with a leading axis of length 1
        bm3d : BM3D results, one tif per frame, in sorted filename order
        n2gt : N2GT predictions with the singleton channel removed
    """
    ## load the flower dataset and build the GT
    flower_all = imread(rawdata_dir / 'artifacts/flower.tif')
    flower_all = normalize3(flower_all, 2, 99.6)
    flower_gt = flower_all.mean(0)

    ## predictions from the single-phase models: 9 masks x 5 repeats
    ## (a disabled branch loading the older flower3_6 ... flower3_16 runs was removed)
    per_mask = [
        np.array([imread(experiments_dir / f'flower/e01/mask{k:02d}_{n}/pred.tif') for n in range(5)])
        for k in range(9)
    ]
    names = "N2V 1x 2x 3x 4x 5x 6x 7x 8x".split(' ')
    data = stak(*per_mask)
    ## NOTE(review): labelled "remove channels dim", but with the extra repeat axis this indexes
    ## axis 2 of (mask, repeat, ...); confirm it is really the channel axis for these preds
    data = data[:, :, 0]

    ## FIX: the old code iterated over an undefined `nlm_vals` (its assignment was commented
    ## out) and read the same 0010 file on every iteration; read it once, keeping a leading axis
    nlm = np.array([
        imread("/projects/project-broaddus/denoise_experiments/flower/e01/nlm/0010/denoised.tif")
    ])
    bm3d = np.array([
        imread(x) for x in sorted(
            glob("/projects/project-broaddus/denoise_experiments/flower/e01/bm3d/*.tif"))
    ])
    n2gt = imread(
        "/projects/project-broaddus/denoise_experiments/flower/e01/n2gt2/pred.tif"
    )
    n2gt = n2gt[:, 0]  ## get rid of singleton channel

    e01 = SimpleNamespace(data=data, names=names)
    dat = SimpleNamespace(gt=flower_gt, e01=e01, all=flower_all, bm3d=bm3d, n2gt=n2gt, nlm=nlm)
    return dat
def train(d, ta=None, end_epoch=300, already_on_cuda=False):
    """Supervised training of ``d.net`` against the mean-of-stack target ``d.ys``.

    Every 256x256 tile of a batch of ``d.xs`` goes through the network. The predictions
    are reshaped back to the batch layout and compared (MSE) with ``d.ys``, which
    broadcasts over the batch axis.

    Parameters
    ----------
    d : SimpleNamespace
        Expected as built by the second ``setup_flower_shutter``: ``d.net``, ``d.xs``
        (there of shape (100, 4, 4, 1, 256, 256)), ``d.ys`` and ``d.savedir``.
        ``d.ws``, ``d.example_xs`` and ``d.xs_fft`` are (re)created here. The net and
        all tensors are moved to the GPU in place.
    ta : training-artifacts namespace or None
        Needs ``.e`` (start epoch), ``.losses`` and ``.lossdists``. Created with
        ``init_training_artifacts()`` when None.
    end_epoch : int
        Last epoch to run (inclusive).
    already_on_cuda : bool
        Unused; tensors are always moved to the GPU.

    Returns
    -------
    ta, updated in place with per-batch losses and per-epoch per-sample loss snapshots.

    Side effects (under ``d.savedir``): ``epochs/rgb_<e>.png`` every 10 epochs,
    ``epochs_npy/arr_<e>.npy`` and ``models/net<e>.pt`` every 100 epochs,
    ``loss.png`` every epoch, ``losses.pkl`` at checkpoints and at the end.
    """
    if ta is None:
        ta = init_training_artifacts()

    ## setup const variables necessary for training
    batch_size = 4
    inds = np.arange(0, d.xs.shape[0])
    d.ws = torch.ones(d.xs.shape).float()  ## pixel weights; not used by the current loss

    ## fixed examples for monitoring: top-left tile of samples 0, 3, 5 and 12
    d.example_xs = d.xs[[0, 3, 5, 12], 0, 0].reshape(-1, 1, 256, 256).clone().cpu()
    ## magnitude spectrum via the legacy torch.fft(x, signal_ndim) API, which expects a trailing
    ## (real, imag) axis; the value is duplicated into both slots, which only scales |FFT| by sqrt(2)
    d.xs_fft = torch.fft(
        (d.example_xs - d.example_xs.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
    d.xs_fft = torch.from_numpy(np.fft.fftshift(d.xs_fft, axes=(-1, -2)))
    lossdist = torch.zeros(d.xs.shape[0]) - 2

    ## move vars to gpu
    d.net = d.net.cuda()
    d.xs = d.xs.cuda()
    d.ys = d.ys.cuda()
    d.xs_fft = d.xs_fft.cuda()
    d.example_xs = d.example_xs.cuda()
    d.ws = d.ws.cuda()

    ## initialize optimizer (after moving the net to the gpu)
    opt = torch.optim.Adam(d.net.parameters(), lr=2e-4)

    plt.figure()
    for e in range(ta.e, end_epoch + 1):
        ta.e = e
        np.random.shuffle(inds)
        ta.lossdists.append(lossdist.numpy().copy())
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(ceil(d.xs.shape[0] / batch_size)):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            batch = d.xs[idxs]
            ## every 256x256 tile is an independent network input
            y1p = d.net(batch.reshape(-1, 1, 256, 256))
            ## FIX: was reshape(4, 4, 4, ...), which hard-coded the batch size and broke
            ## on a final partial batch; restore the batch's own layout instead
            y1p = y1p.reshape(batch.shape)
            dims = tuple(range(1, y1p.ndim))  ## all dims except batch
            loss_per_patch = ((y1p - d.ys)**2).mean(dims)
            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))
            opt.zero_grad()
            loss.backward()
            opt.step()

        ## predict on examples and save a snapshot
        if e % 10 == 0:
            with torch.no_grad():
                example_yp = d.net(d.example_xs)
                yp_fft = torch.fft(
                    (example_yp - example_yp.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
                ## shift frequency domain s.t. zero freq is at the center of the array
                yp_fft = torch.from_numpy(np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()
                ## stack (real space, real fft, predictions, prediction fft) along a new dimension
                rgb = torch.stack([d.example_xs, d.xs_fft, example_yp, yp_fft], 0).cpu().detach().numpy()
                arr = rgb.copy()
                ## first normalize each type to [0,1] independently; dims = type, sample, channel, y, x
                rgb = normalize3(rgb, axs=(1, 2, 3, 4))
                ## then renormalize the two fft types with a clipped upper percentile
                rgb[[1, 3]] = normalize3(rgb[[1, 3]], pmin=0, pmax=99.0, axs=(1, 2, 3, 4))
                ## remove channels and lay out as a 2D image: samples down, types across
                rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
                    if e % 100 == 0:
                        np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot loss
        batches_per_epoch = ceil(d.xs.shape[0] / batch_size)
        epochs = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(epochs, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / f'loss.png', dpi=300)

        ## save model weights
        if e % 100 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')
            pklsave(ta.losses, d.savedir / f'losses.pkl')

    pklsave(ta.losses, d.savedir / f'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta
def train(d, ta=None, end_epoch=300, mask_shape=(1, 2, 3, 4)):
    """Train ``d.net`` with a sparse blind-spot masking loss.

    Every batch, a random set of "center" pixels (about 2% of the patch,
    sampled with replacement) is drawn. The centers and their horizontal
    neighbours at the x-offsets in ``mask_shape`` are overwritten with
    uniform noise (``torch.rand``) before the batch is fed to ``d.net``. The
    squared error against the clean input is then averaged over the center
    pixels only, per patch.

    Parameters
    ----------
    d : namespace
        Must provide ``x1_all`` (tensor whose trailing two dims are the patch
        y/x size -- presumably ``(N, C, Y, X)``, TODO confirm), ``net`` and
        ``savedir`` (a path supporting ``/``). This function sets
        ``d.w1_all``, ``d.eg_xs`` and ``d.xs_fft`` and moves ``d.net``,
        ``d.x1_all``, ``d.w1_all``, ``d.xs_fft``, ``d.eg_xs`` to CUDA.
    ta : namespace, optional
        Training artifacts with ``e``, ``losses`` and ``lossdists``; created
        via ``init_training_artifacts()`` when ``None``. Training resumes at
        epoch ``ta.e``.
    end_epoch : int
        Last epoch to run (inclusive).
    mask_shape : iterable of int
        Horizontal offsets of the neighbour pixels damaged together with each
        center pixel. Neighbours are damaged but are not part of the loss.

    Returns
    -------
    ``ta``, updated in place (``e``, ``losses``, ``lossdists``).

    Side effects (all under ``d.savedir``): a monitoring PNG every 10 epochs,
    the raw monitoring array every 100 epochs, a checkpoint every 50 epochs,
    ``loss.png`` every epoch, and ``losses.pkl`` plus a final checkpoint on
    exit.
    """
    if ta is None:
        ta = init_training_artifacts()

    ## set up const variables necessary for training
    batch_size = 4
    inds = np.arange(0, d.x1_all.shape[0])
    patch_size = d.x1_all.shape[2:]
    d.w1_all = torch.ones(d.x1_all.shape).float()

    ## set up variables for monitoring training
    # d.eg_xs = d.x1_all[inds[::floor(np.sqrt(len(inds)))]].clone()
    d.eg_xs = d.x1_all[[0, 3, 5, 12]].clone()
    # NOTE: `torch.fft(x, 2)` is the legacy (torch < 1.8) complex FFT; the
    # trailing size-2 dim holds (real, imag), here both set to the input.
    d.xs_fft = torch.fft((d.eg_xs - d.eg_xs.mean())[..., None][..., [0, 0]],
                         2).norm(p=2, dim=-1)
    # .cpu() so fftshift also works when d.x1_all already lives on the GPU
    d.xs_fft = torch.from_numpy(
        np.fft.fftshift(d.xs_fft.cpu(), axes=(-1, -2)))
    lossdist = torch.zeros(d.x1_all.shape[0]) - 2

    ## move everything to cuda
    d.net = d.net.cuda()
    d.x1_all = d.x1_all.cuda()
    d.w1_all = d.w1_all.cuda()
    d.xs_fft = d.xs_fft.cuda()
    d.eg_xs = d.eg_xs.cuda()

    opt = torch.optim.Adam(d.net.parameters(), lr=2e-5)

    ## mask builders: 0 = untouched, 1 = damaged only, 2 = damaged + loss target.
    ## Defined once here instead of once per batch.
    def random_pixel_mask():
        "Unused alternative: ~2% random center pixels (value 2), no neighbours."
        n = int(np.prod(patch_size) * 0.02)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        # z_inds = np.random.randint(0,32,64*64*1)
        ma = np.zeros(patch_size)
        ma[y_inds, x_inds] = 2
        return ma

    def sparse_3set_mask():
        "~2% random center pixels (value 2) plus horizontal neighbours at offsets `mask_shape` (value 1)."
        n = int(np.prod(patch_size) * 0.02)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        ma = np.zeros(patch_size)
        for i in mask_shape:
            m = x_inds - i >= 0  # skip neighbours falling off the left edge
            ma[y_inds[m], x_inds[m] - i] = 1
            m = x_inds + i < patch_size[1]  # ... or off the right edge
            ma[y_inds[m], x_inds[m] + i] = 1
        ma = ma.astype(np.uint8)
        ma[y_inds, x_inds] = 2  # centers override any overlapping neighbour
        return ma

    def checkerboard_mask():
        "Unused alternative: index-parity pattern, inverted on odd epochs (reads the loop's `e`)."
        ma = np.indices(patch_size).transpose((1, 2, 0))
        ma = np.floor(ma / (1, 256)).sum(-1) % 2 == 0
        ma = 2 * ma
        if e % 2 == 1:
            ma = 2 - ma
        return ma

    plt.figure()
    for e in range(ta.e, end_epoch + 1):
        ta.e = e
        np.random.shuffle(inds)
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(ceil(d.x1_all.shape[0] / batch_size)):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.x1_all[idxs]
            w1 = d.w1_all[idxs]  # advanced indexing -> a copy; d.w1_all stays all-ones

            ma = sparse_3set_mask()

            ## apply mask to input (np.float was an alias of builtin float)
            w1[:, :] = torch.from_numpy(ma.astype(float)).cuda()
            x1_damaged = x1.clone()
            x1_damaged[w1 > 0] = torch.rand(x1.shape).cuda()[w1 > 0]

            y1p = d.net(x1_damaged)

            dims = (1, 2, 3)  ## all dims except batch
            tm = (w1 == 2).float()  ## target mask: loss only at center pixels
            loss_per_patch = (tm * torch.abs(y1p - x1)**2).sum(dims) / tm.sum(dims)
            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            loss.backward()
            opt.step()

        ## predict on examples and save predictions as images
        with torch.no_grad():
            example_yp = d.net(d.eg_xs)
            yp_fft = torch.fft(
                (example_yp - example_yp.mean())[..., None][..., [0, 0]],
                2).norm(p=2, dim=-1)
            yp_fft = torch.from_numpy(
                np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()
            ## rows: inputs, last batch's mask (/2 -> {0, .5, 1}), input fft,
            ## predictions, prediction fft
            rgb = torch.stack([
                d.eg_xs, w1[[0] * len(d.eg_xs)] / 2, d.xs_fft, example_yp, yp_fft
            ], 0).cpu().detach().numpy()
        arr = rgb.copy()
        # type,samples,channels,y,x
        rgb = normalize3(rgb, axs=(1, 2, 3, 4))
        rgb[[2, 4]] = normalize3(rgb[[2, 4]], pmin=0, pmax=99.0, axs=(1, 2, 3, 4))
        # remove channels and permute
        rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if e % 10 == 0:
                io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
            if e % 100 == 0:
                np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot the loss after each epoch
        ta.lossdists.append(lossdist.numpy().copy())
        batches_per_epoch = ceil(d.x1_all.shape[0] / batch_size)
        x_axis = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(x_axis, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / 'loss.png', dpi=300)

        ## and save the model state
        if e % 50 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / 'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta
def train(d, ta=None, end_epoch=301):
    """Jointly train ``d.net`` and ``d.net2`` with two chained masked losses.

    Per batch, two independent sparse masks are drawn (centers = 2, their
    horizontal neighbours = 1).

    - ``d.net`` sees the batch with the first mask's pixels replaced by
      uniform noise. It is scored (squared error vs. the clean input) at the
      first mask's centers.
    - ``d.net2`` sees ``d.net``'s output with the second mask's pixels
      replaced by noise. It is scored against ``d.net``'s output at the second
      mask's centers.

    The two per-patch losses are summed. ``d.net``'s output is not detached,
    so the second term back-propagates into ``d.net`` too. Each network has
    its own Adam optimizer (lr 2e-5).

    NOTE: this definition shadows the earlier ``train`` in this module.

    Parameters
    ----------
    d : namespace
        Must provide ``x1_all`` (tensor whose trailing two dims are the patch
        y/x size, presumably ``(N, C, Y, X)``; it is indexed with CUDA masks,
        so it is assumed to already be on the GPU -- TODO confirm), ``net``,
        ``net2`` and ``savedir`` (a path supporting ``/``).
    ta : namespace, optional
        Training artifacts (``e``, ``losses``, ``lossdists``); created with
        ``init_training_artifacts()`` when ``None``. Resumes at ``ta.e``.
    end_epoch : int
        Exclusive upper bound on the epochs run.

    Returns
    -------
    ``ta``, updated in place. ``ta.lossdists`` is appended at the START of
    each epoch, so each entry holds the previous epoch's per-patch losses
    (the first one is the initial fill of -2).

    Side effects (under ``d.savedir``): a monitoring PNG every 10 epochs,
    the raw array every 100 epochs, ``loss.png`` every epoch, checkpoints of
    both networks every 100 epochs, and ``losses.pkl`` plus final checkpoints
    on exit.
    """
    if ta is None:
        ta = init_training_artifacts()

    batch_size = 4
    inds = np.arange(0, d.x1_all.shape[0])
    patch_size = d.x1_all.shape[2:]

    ## fixed examples (and their centered, shifted fft) for monitoring
    # example_xs = d.x1_all[inds[::floor(np.sqrt(len(inds)))]].clone()
    example_xs = d.x1_all[[0, 3, 5, 12]].clone()
    n_examples = len(example_xs)
    # NOTE: `torch.fft(x, 2)` is the legacy (torch < 1.8) complex FFT; the
    # trailing size-2 dim holds (real, imag), here both set to the input.
    xs_fft = torch.fft((example_xs - example_xs.mean())[..., None][..., [0, 0]],
                       2).norm(p=2, dim=-1)
    xs_fft = torch.from_numpy(np.fft.fftshift(xs_fft.cpu(), axes=(-1, -2))).cuda()

    opt = torch.optim.Adam(d.net.parameters(), lr=2e-5)
    opt2 = torch.optim.Adam(d.net2.parameters(), lr=2e-5)

    lossdist = torch.zeros(d.x1_all.shape[0]) - 2

    ## mask builders: 0 = untouched, 1 = damaged only, 2 = damaged + loss target.
    ## Defined once here instead of once per batch.
    def random_pixel_mask():
        "Unused alternative: ~2% random center pixels (value 2), no neighbours."
        n = int(np.prod(patch_size) * 0.02)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        # z_inds = np.random.randint(0,32,64*64*1)
        ma = np.zeros(patch_size)
        ma[y_inds, x_inds] = 2
        return ma

    def sparse_3set_mask(p=0.02, xs=(1, 2), ys=()):
        "Fraction `p` of random center pixels (value 2) plus neighbours at x-offsets `xs` and y-offsets `ys` (value 1)."
        n = int(np.prod(patch_size) * p)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        ma = np.zeros(patch_size)
        for i in xs:
            m = x_inds - i >= 0  # skip neighbours falling off the patch
            ma[y_inds[m], x_inds[m] - i] = 1
            m = x_inds + i < patch_size[1]
            ma[y_inds[m], x_inds[m] + i] = 1
        for i in ys:
            m = y_inds - i >= 0
            ma[y_inds[m] - i, x_inds[m]] = 1
            m = y_inds + i < patch_size[0]
            ma[y_inds[m] + i, x_inds[m]] = 1
        ma = ma.astype(np.uint8)
        ma[y_inds, x_inds] = 2  # centers override any overlapping neighbour
        return ma

    def checkerboard_mask():
        "Unused alternative: index-parity pattern, inverted on odd epochs (reads the loop's `e`)."
        ma = np.indices(patch_size).transpose((1, 2, 0))
        ma = np.floor(ma / (1, 256)).sum(-1) % 2 == 0
        ma = 2 * ma
        if e % 2 == 1:
            ma = 2 - ma
        return ma

    plt.figure()
    for e in range(ta.e, end_epoch):
        ta.e = e
        np.random.shuffle(inds)
        ta.lossdists.append(lossdist.numpy().copy())
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(ceil(d.x1_all.shape[0] / batch_size)):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.x1_all[idxs]
            n_batch = x1.shape[0]  # the last batch may be smaller than batch_size

            # np.float was an alias of builtin float (removed in NumPy 1.24)
            ma = sparse_3set_mask(xs=(1, 2)).astype(float)
            ma2 = sparse_3set_mask(xs=(1, 2)).astype(float)

            ## stage 1: damage the input, let net reconstruct it
            ma = torch.from_numpy(ma).cuda()
            x1_damaged = x1.clone()
            x1_damaged[:, :, ma > 0] = torch.rand(x1.shape).cuda()[:, :, ma > 0]
            y1p = d.net(x1_damaged)

            ## stage 2: damage net's output, let net2 reconstruct it.
            ## net2 must see the DAMAGED copy, otherwise the target pixels are
            ## visible in its input and it can learn the identity.
            ma2 = torch.from_numpy(ma2).cuda()
            y1p_damaged = y1p.clone()
            y1p_damaged[:, :, ma2 > 0] = torch.rand(y1p.shape).cuda()[:, :, ma2 > 0]
            y2p = d.net2(y1p_damaged)

            dims = (1, 2, 3)  ## all dims except batch
            ## target masks, repeated to the actual batch size (not a fixed 4)
            tm1 = (ma == 2).float().repeat(n_batch, 1, 1, 1)
            tm2 = (ma2 == 2).float().repeat(n_batch, 1, 1, 1)
            loss1 = (tm1 * torch.abs(y1p - x1)**2).sum(dims) / tm1.sum(dims)
            loss2 = (tm2 * torch.abs(y2p - y1p)**2).sum(dims) / tm2.sum(dims)
            loss_per_patch = loss1 + loss2

            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            opt2.zero_grad()
            loss.backward()
            opt.step()
            opt2.step()

        ## predict on examples and save each epoch
        with torch.no_grad():
            example_yp = d.net(example_xs)
            example_yp2 = d.net2(example_yp)
            yp_fft = torch.fft(
                (example_yp2 - example_yp2.mean())[..., None][..., [0, 0]],
                2).norm(p=2, dim=-1)
            yp_fft = torch.from_numpy(
                np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()
            ## rows: inputs, last batch's first mask (/2 -> {0, .5, 1}),
            ## input fft, net2 predictions, prediction fft
            rgb = torch.stack([
                example_xs, ma.float().repeat(n_examples, 1, 1, 1) / 2, xs_fft,
                example_yp2, yp_fft
            ], 0).cpu().detach().numpy()
        arr = rgb.copy()
        # type,samples,channels,y,x
        rgb = normalize3(rgb, axs=(1, 2, 3, 4))
        rgb[[2, 4]] = normalize3(rgb[[2, 4]], pmin=0, pmax=99.0, axs=(1, 2, 3, 4))
        # remove channels and permute
        rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if e % 10 == 0:
                io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
            if e % 100 == 0:
                np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot the loss after each epoch
        batches_per_epoch = ceil(d.x1_all.shape[0] / batch_size)
        epochs = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(epochs, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / 'loss.png', dpi=300)

        ## save both models (was `savedir`, an undefined name -> NameError)
        if e % 100 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')
            torch.save(d.net2.state_dict(), d.savedir / f'models/net2_{e:03d}.pt')

    pklsave(ta.losses, d.savedir / 'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    torch.save(d.net2.state_dict(), d.savedir / f'models/net2_{ta.e:03d}.pt')
    return ta