def setup_flower_shutter(rawdata, savedir):
    """Build the experiment namespace for the flower/shutter stack.

    Clears and re-creates ``savedir`` (``wipe_dirs`` then ``init_dirs``), reads
    ``rawdata`` and normalises it with ``normalize3(..., 2, 99.6)``. It then loads
    the initial 2D U-Net weights from a fixed path and cuts the stack into a 4x4
    grid of 256x256 tiles. The mean over the leading axis is stored as the target.
    Previews of the first leading-axis entry and of the mean are written to
    ``xs.png`` / ``ys.png``.

    Args:
        rawdata: path handed to ``imread``. The loaded stack must hold exactly
            100 * 1024 * 1024 values for the reshape below (e.g. 100 frames of
            1024x1024 -- TODO confirm).
        savedir: output directory (str or Path).

    Returns:
        SimpleNamespace with ``net`` (on CPU), ``savedir`` (Path), ``xs`` of shape
        (100, 4, 4, 1, 256, 256), ``ys`` of shape (4, 4, 1, 256, 256) and
        ``cuda = False``.
    """
    out_dir = Path(savedir)
    for prepare in (wipe_dirs, init_dirs):
        prepare(out_dir)

    # NOTE(review): unclear whether normalize3 normalises across all axes at once -- confirm.
    stack = normalize3(imread(rawdata), 2, 99.6)

    d = SimpleNamespace()
    d.net = torch_models.Unet2_2d(16, [[1], [1]], finallayer=nn.ReLU)
    init_weights_file = '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit.pt'
    d.net.load_state_dict(torch.load(init_weights_file))
    d.savedir = out_dir

    # View the stack as (100, tile_row, 256, tile_col, 256, 1) and move both tile
    # indices in front of the pixel axes: (100, tile_row, tile_col, 1, 256, 256).
    grid = torch.from_numpy(stack).float().reshape(100, 4, 256, 4, 256, 1)
    d.xs = grid.permute(0, 1, 3, 5, 2, 4)
    d.ys = d.xs.mean(0)

    def save_preview(name, tiles):
        # tiles: (tile_row, tile_col, y, x) -> one 2D mosaic image
        io.imsave(out_dir / name, collapse2(tiles.numpy(), "12yx", "1y,2x"))

    save_preview('xs.png', d.xs[0, :, :, 0])
    save_preview('ys.png', d.ys[:, :, 0])

    d.cuda = False
    return d
def datagen(params=None, savedir=None):
    """Sample random 3D training patches from the C. elegans ISBI time series.

    For each of the 190 time points ``t000`` .. ``t189``, the volume is loaded and
    normalised with ``normalize3`` using randomly jittered bounds
    (pmin ~ U(1, 3), pmax ~ U(99.5, 99.8)). Then 10 patches of shape
    (32, 64, 64) are drawn by rejection sampling. A candidate window is redrawn
    until its mean intensity is at least 0.03, so near-empty background is
    skipped. No augmentation is applied.

    NOTE(review): a later ``datagen`` definition in this module shadows this one.
    NOTE(review): the rejection loop never ends if a volume contains no window
    whose mean reaches the threshold.

    Args:
        params: unused; kept only for backward compatibility.
        savedir: if given (str or Path), these files are written there:
            - xy / xz preview PNGs;
            - the patch array as ``data_cele.npz`` (key ``arr_0``);
            - ``slicelist_cele.pkl``, a list with one entry per time point. Each
              entry is the list of slices of that time point's patches, so
              ``slicelist[t][s]`` locates ``data[t, s]`` in volume ``t``.

    Returns:
        ndarray of patches. Presumably shaped (190, 10, 1, 32, 64, 64) =
        (time, sample, channel, z, y, x), assuming ``random_slice`` always
        returns a window of exactly the requested shape.
    """
    patch_shape = (32, 64, 64)
    min_patch_mean = 0.03  # threshold chosen by manual inspection
    patches_per_timepoint = 10

    data = []       # data[t][s] == (patch,)
    slicelist = []  # slicelist[t][s] == slice of data[t][s] inside volume t
    for i in np.r_[:190]:
        img = imread(
            f'/lustre/projects/project-broaddus/devseg_data/raw/celegans_isbi/Fluo-N3DH-CE/01/t{i:03d}.tif'
        )
        pmin, pmax = np.random.uniform(1, 3), np.random.uniform(99.5, 99.8)
        img = normalize3(img, pmin, pmax).astype(np.float32, copy=False)

        patches, slices = [], []
        for _ in range(patches_per_timepoint):
            ## select patches with interesting content (rejection sampling)
            ss = random_slice(img.shape, patch_shape)
            while img[ss].mean() < min_patch_mean:
                ss = random_slice(img.shape, patch_shape)
            patches.append((img[ss].copy(), ))
            slices.append(ss)
        data.append(patches)
        # Accumulated across ALL time points so every patch in `data` keeps its slice.
        slicelist.append(slices)

    data = np.array(data)
    print("data.shape: ", data.shape)

    if savedir:
        from pathlib import Path
        savedir = Path(savedir)  # accept str as well as Path
        # xy view: central z-plane (z=16) of every patch, grey replicated to RGB
        rgb = collapse2(data[:, :, :, 16], 'tscyx', 'ty,sx,c')[..., [0, 0, 0]]
        rgb = normalize3(rgb)
        io.imsave(savedir / 'data_xy_cele.png', rgb)
        # xz view: central y-row (y=32) of every patch
        rgb = collapse2(data[:, :, :, :, 32], 'tsczx', 'tz,sx,c')[..., [0, 0, 0]]
        rgb = normalize3(rgb)
        io.imsave(savedir / 'data_xz_cele.png', rgb)
        np.savez_compressed(savedir / 'data_cele.npz', data)
        pklsave(slicelist, savedir / 'slicelist_cele.pkl')

    return data
def setup(savedir):
    """Build the namespace for 3D training on the saved C. elegans patches.

    Clears and re-creates ``savedir`` (``wipe_dirs`` then ``init_dirs``). It then
    loads the patch array from a fixed ``data_cele.npz`` (key ``arr_0``), which is
    presumably the output of a previous ``datagen`` run. The (time, sample) axes
    are flattened into a single sample axis. Finally the initial 3D U-Net weights
    are loaded from a fixed path.

    Args:
        savedir: output directory (str or Path).

    Returns:
        SimpleNamespace with:
            - ``net``: on the GPU;
            - ``savedir``: a Path;
            - ``x1_all``: a float32 CUDA tensor, presumably of shape
              (t*s, 1, z, y, x) -- TODO confirm.
    """
    savedir = Path(savedir)
    wipe_dirs(savedir)
    init_dirs(savedir)

    # The context manager closes the NpzFile's underlying file handle; the
    # array itself is fully read into memory before the file is closed.
    with np.load('/lustre/projects/project-broaddus/denoise_experiments/cele/e01/data_cele.npz') as npz:
        data = npz['arr_0']
    # Add a dummy leading 'r' axis, then reorder to (c, t*s, r, z, y, x) and keep
    # channel 0. The size-1 'r' axis then acts as the channel axis
    # (assuming collapse2 reorders/merges axes as its pattern strings read).
    data = collapse2(data[None], 'rtsczyx', 'c,ts,r,z,y,x')[0]

    d = SimpleNamespace()
    d.net = torch_models.Unet2(16, [[1], [1]], finallayer=nn.ReLU).cuda()
    d.net.load_state_dict(
        torch.load(
            '/lustre/projects/project-broaddus/denoise_experiments/flower/models/net_randinit3D.pt'
        ))
    d.savedir = savedir
    d.x1_all = torch.from_numpy(data).float().cuda()
    return d
def datagen(savedir=None):
    """Cut the first 10 images of ``pred_flower.tif`` into 256x256 training tiles.

    The images are normalised with ``normalize3`` using fixed bounds (2, 99.6).
    Each 1024x1024 area is then split into a 4x4 grid of tiles, in row-major
    tile order. This assumes each image is 1024x1024 -- TODO confirm.

    Args:
        savedir: if truthy, a preview mosaic ``data_xy_flower.png`` and the
            tiles plus bounds (``data_flower.npz``, keys ``data``, ``pmin``,
            ``pmax``) are written there. It must support ``/`` (e.g. a Path).

    Returns:
        SimpleNamespace with ``data`` (float32, shape (n_tiles, 1, 256, 256)),
        ``pmin`` and ``pmax``.
    """
    stack = imread('/lustre/projects/project-broaddus/denoise_experiments/flower/e02/pred_flower.tif')[:10]
    print(stack.shape)

    pmin, pmax = 2, 99.6  # fixed normalisation bounds
    print(f"pmin = {pmin}; pmax = {pmax}")
    stack = normalize3(stack, pmin, pmax).astype(np.float32, copy=False)

    # (n, 4*256, 4*256) -> (n, tile_row, tile_col, 256, 256) -> (n*16, 1, 256, 256)
    tiles = stack.reshape((-1, 4, 256, 4, 256))
    tiles = tiles.transpose((0, 1, 3, 2, 4))
    data = tiles.reshape((-1, 1, 256, 256))
    print("data.shape: ", data.shape)

    if savedir:
        # one grey-as-RGB image per tile, laid out in a grid
        preview = collapse2(data, 'scyx', 's,y,x,c')[..., [0, 0, 0]]
        preview = plotgrid([normalize3(preview)], 10)
        io.imsave(savedir / 'data_xy_flower.png', preview)
        np.savez_compressed(savedir / 'data_flower.npz', data=data, pmin=pmin, pmax=pmax)

    return SimpleNamespace(data=data, pmin=pmin, pmax=pmax)
def train(d, ta=None, end_epoch=300, mask_shape=(1, 2, 3, 4)):
    """Train ``d.net`` to predict randomly hidden pixels from their surroundings.

    Each batch uses one random 2D mask:
        - About 2% of pixel positions (drawn with replacement) become targets
          (mask value 2).
        - For every target, the pixels ``i`` columns to its left and right, for
          each ``i`` in ``mask_shape``, are also hidden (mask value 1).
    All hidden pixels are replaced with U[0, 1) noise before the forward pass.
    The loss is the squared error averaged over the target pixels of each patch.

    NOTE(review): later ``train`` definitions in this module shadow this one.

    Args:
        d: namespace with ``net``, ``x1_all`` and ``savedir``.
            - ``x1_all`` is assumed to be (N, C, Y, X) with N >= 13 -- TODO confirm.
            - ``savedir`` must support ``/`` and contain ``epochs/``,
              ``epochs_npy/`` and ``models/``.
            This function sets ``d.w1_all``, ``d.eg_xs`` and ``d.xs_fft``, and
            moves ``d.net`` and these tensors to the GPU.
        ta: training artifacts (``e``, ``losses``, ``lossdists``), created with
            ``init_training_artifacts()`` when None. Training resumes at ``ta.e``.
        end_epoch: last epoch to run (inclusive).
        mask_shape: horizontal offsets of the hidden neighbours of each target.

    Returns:
        ``ta``, updated in place.
    """
    if ta is None:
        ta = init_training_artifacts()

    ## constants
    batch_size = 4
    n_samples = d.x1_all.shape[0]
    batches_per_epoch = ceil(n_samples / batch_size)
    inds = np.arange(0, n_samples)
    patch_size = d.x1_all.shape[2:]
    d.w1_all = torch.ones(d.x1_all.shape).float()

    ## monitoring: fixed examples and their centred power spectra.
    # The legacy torch.fft(input, signal_ndim) API expects a trailing (real, imag)
    # axis; the image is copied into both slots before taking the magnitude.
    d.eg_xs = d.x1_all[[0, 3, 5, 12]].clone()
    d.xs_fft = torch.fft((d.eg_xs - d.eg_xs.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
    # .cpu() so this also works when x1_all already lives on the GPU
    # (numpy cannot read CUDA tensors); it is a no-op for CPU tensors.
    d.xs_fft = torch.from_numpy(np.fft.fftshift(d.xs_fft.cpu(), axes=(-1, -2)))
    lossdist = torch.zeros(n_samples) - 2

    ## move everything to the GPU
    d.net = d.net.cuda()
    d.x1_all = d.x1_all.cuda()
    d.w1_all = d.w1_all.cuda()
    d.xs_fft = d.xs_fft.cuda()
    d.eg_xs = d.eg_xs.cuda()

    opt = torch.optim.Adam(d.net.parameters(), lr=2e-5)

    def sparse_3set_mask():
        "Random mask: 2 at target pixels, 1 at their hidden horizontal neighbours, 0 elsewhere."
        n = int(np.prod(patch_size) * 0.02)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        ma = np.zeros(patch_size)
        for i in mask_shape:
            m = x_inds - i >= 0  # keep left neighbours inside the patch
            ma[y_inds[m], x_inds[m] - i] = 1
            m = x_inds + i < patch_size[1]  # keep right neighbours inside the patch
            ma[y_inds[m], x_inds[m] + i] = 1
        ma = ma.astype(np.uint8)
        ma[y_inds, x_inds] = 2  # targets override neighbour marks
        return ma

    plt.figure()
    for e in range(ta.e, end_epoch + 1):
        ta.e = e
        np.random.shuffle(inds)
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(batches_per_epoch):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.x1_all[idxs]
            w1 = d.w1_all[idxs]  # advanced indexing -> a copy; d.w1_all is not modified

            ## hide masked pixels behind uniform noise
            ma = sparse_3set_mask()
            w1[:, :] = torch.from_numpy(ma.astype(np.float64)).cuda()
            x1_damaged = x1.clone()
            x1_damaged[w1 > 0] = torch.rand(x1.shape).cuda()[w1 > 0]

            y1p = d.net(x1_damaged)

            ## squared error averaged over the target pixels of each patch
            dims = (1, 2, 3)  # all dims except batch
            tm = (w1 == 2).float()  # target mask
            loss_per_patch = (tm * torch.abs(y1p - x1)**2).sum(dims) / tm.sum(dims)
            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            loss.backward()
            opt.step()

        ## predict on the fixed examples and save a montage
        with torch.no_grad():
            example_yp = d.net(d.eg_xs)
            yp_fft = torch.fft((example_yp - example_yp.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
            yp_fft = torch.from_numpy(np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()
            # Montage rows, in order:
            #   inputs; last batch's mask (first sample, scaled to [0, 1]);
            #   input fft; predictions; prediction fft
            rgb = torch.stack([
                d.eg_xs, w1[[0] * len(d.eg_xs)] / 2, d.xs_fft, example_yp, yp_fft
            ], 0).cpu().detach().numpy()
        arr = rgb.copy()  # dims = type, sample, channel, y, x
        rgb = normalize3(rgb, axs=(1, 2, 3, 4))
        rgb[[2, 4]] = normalize3(rgb[[2, 4]], pmin=0, pmax=99.0, axs=(1, 2, 3, 4))
        # drop the channel axis and tile into one 2D image
        rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if e % 10 == 0:
                io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
            if e % 100 == 0:
                np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot the loss after each epoch
        ta.lossdists.append(lossdist.numpy().copy())
        x_axis = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(x_axis, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / 'loss.png', dpi=300)

        ## and save the model state
        if e % 50 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / 'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta
def train(d, ta=None, end_epoch=300, already_on_cuda=False):
    """Train ``d.net`` to map every noisy tile of a frame onto the target ``d.ys``.

    Each batch item of ``d.xs`` is split into its individual tiles and passed
    through the 2D network. The output is reshaped back to the item's shape and
    compared to ``d.ys`` (broadcast over the batch) with a mean squared error.

    NOTE(review): later ``train`` definitions in this module shadow this one.

    Args:
        d: namespace with ``net``, ``xs``, ``ys`` and ``savedir``.
            - ``xs``: presumably (N, 4, 4, 1, Y, X) with N >= 13, as built by
              ``setup_flower_shutter``.
            - ``ys``: presumably (4, 4, 1, Y, X).
            - ``savedir``: must support ``/`` and contain ``epochs/``,
              ``epochs_npy/`` and ``models/``.
            This function sets ``d.ws``, ``d.example_xs`` and ``d.xs_fft``, and
            moves ``d.net`` and these tensors to the GPU.
        ta: training artifacts (``e``, ``losses``, ``lossdists``), created with
            ``init_training_artifacts()`` when None. Training resumes at ``ta.e``.
        end_epoch: last epoch to run (inclusive).
        already_on_cuda: currently ignored; everything is always moved to the GPU.

    Returns:
        ``ta``, updated in place.
    """
    if ta is None:
        ta = init_training_artifacts()

    ## constants
    batch_size = 4
    n_samples = d.xs.shape[0]
    batches_per_epoch = ceil(n_samples / batch_size)
    inds = np.arange(0, n_samples)
    tile_shape = tuple(d.xs.shape[-3:])  # (channel, y, x) of one network input
    # NOTE(review): d.ws is as large as d.xs and is moved to the GPU, but this
    # loss never uses it; kept only because other code may read d.ws.
    d.ws = torch.ones(d.xs.shape).float()

    ## monitoring: fixed example tiles and their centred power spectra.
    # The legacy torch.fft(input, signal_ndim) API expects a trailing (real, imag)
    # axis; the image is copied into both slots before taking the magnitude.
    d.example_xs = d.xs[[0, 3, 5, 12], 0, 0].reshape(-1, *tile_shape).clone().cpu()
    d.xs_fft = torch.fft((d.example_xs - d.example_xs.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
    d.xs_fft = torch.from_numpy(np.fft.fftshift(d.xs_fft, axes=(-1, -2)))
    lossdist = torch.zeros(n_samples) - 2

    ## move everything to the GPU (regardless of already_on_cuda)
    d.net = d.net.cuda()
    d.xs = d.xs.cuda()
    d.ys = d.ys.cuda()
    d.xs_fft = d.xs_fft.cuda()
    d.example_xs = d.example_xs.cuda()
    d.ws = d.ws.cuda()

    ## build the optimizer after moving the net, as the torch.optim docs recommend
    opt = torch.optim.Adam(d.net.parameters(), lr=2e-4)

    plt.figure()
    for e in range(ta.e, end_epoch + 1):
        ta.e = e
        np.random.shuffle(inds)
        # per-sample losses of the previous epoch (all -2 on the first epoch)
        ta.lossdists.append(lossdist.numpy().copy())
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(batches_per_epoch):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.xs[idxs]
            # Flatten (batch, row, col) into one batch axis for the 2D net, then
            # restore this batch's own shape. This also handles a final partial batch.
            y1p = d.net(x1.reshape(-1, *tile_shape)).reshape(x1.shape)

            dims = tuple(range(1, y1p.dim()))  # all dims except batch
            loss_per_patch = ((y1p - d.ys)**2).mean(dims)
            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            loss.backward()
            opt.step()

        ## every 10th epoch: predict on the examples and save a montage
        if e % 10 == 0:
            with torch.no_grad():
                example_yp = d.net(d.example_xs)
                yp_fft = torch.fft((example_yp - example_yp.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
                ## shift the frequency domain so the zero frequency is at the centre
                yp_fft = torch.from_numpy(np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()
                ## rows: inputs, input fft, predictions, prediction fft
                rgb = torch.stack([d.example_xs, d.xs_fft, example_yp, yp_fft], 0).cpu().detach().numpy()
            arr = rgb.copy()
            ## normalize each row type independently (dims = type, sample, channel, y, x) ...
            rgb = normalize3(rgb, axs=(1, 2, 3, 4))
            ## ... then re-normalize the two fft rows separately
            rgb[[1, 3]] = normalize3(rgb[[1, 3]], pmin=0, pmax=99.0, axs=(1, 2, 3, 4))
            ## drop the channel axis and tile into one 2D image
            rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
                if e % 100 == 0:
                    np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot the loss curve (x axis in epochs)
        epochs = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(epochs, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / 'loss.png', dpi=300)

        ## save model weights
        if e % 100 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / 'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta
def train(d, ta=None, end_epoch=301):
    """Jointly train a two-stage masked-pixel model (``d.net`` then ``d.net2``).

    Each batch draws two independent random masks.
        - Stage 1: pixels hidden by the first mask are replaced with U[0, 1)
          noise, and ``d.net`` predicts the input. Loss: squared error averaged
          over the first mask's target pixels.
        - Stage 2: pixels of ``d.net``'s output hidden by the second mask are
          replaced with noise, and ``d.net2`` predicts ``d.net``'s output. Loss:
          squared error averaged over the second mask's target pixels.
    Both terms are summed per patch, and both networks are stepped on the total.

    NOTE(review): only ``d.net``'s weights are written to disk; ``d.net2`` is never saved.

    Args:
        d: namespace with ``net``, ``net2``, ``x1_all`` and ``savedir``.
            - ``net``, ``net2`` and ``x1_all`` must already be on the GPU; this
              function does not move them.
            - ``x1_all`` is assumed to be (N, C, Y, X) with N >= 13 -- TODO confirm.
            - ``savedir`` must support ``/`` and contain ``epochs/``,
              ``epochs_npy/`` and ``models/``.
        ta: training artifacts (``e``, ``losses``, ``lossdists``), created with
            ``init_training_artifacts()`` when None. Training resumes at ``ta.e``.
        end_epoch: exclusive upper bound of the epoch range.

    Returns:
        ``ta``, updated in place.
    """
    if ta is None:
        ta = init_training_artifacts()

    batch_size = 4
    n_samples = d.x1_all.shape[0]
    batches_per_epoch = ceil(n_samples / batch_size)
    inds = np.arange(0, n_samples)
    patch_size = d.x1_all.shape[2:]

    ## monitoring: fixed examples and the centred power spectrum of the inputs.
    # The legacy torch.fft(input, signal_ndim) API expects a trailing (real, imag)
    # axis; the image is copied into both slots before taking the magnitude.
    example_xs = d.x1_all[[0, 3, 5, 12]].clone()
    xs_fft = torch.fft((example_xs - example_xs.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
    xs_fft = torch.from_numpy(np.fft.fftshift(xs_fft.cpu(), axes=(-1, -2))).cuda()

    opt = torch.optim.Adam(d.net.parameters(), lr=2e-5)
    opt2 = torch.optim.Adam(d.net2.parameters(), lr=2e-5)

    lossdist = torch.zeros(n_samples) - 2

    def sparse_3set_mask(p=0.02, xs=(1, 2), ys=()):
        "Random mask: 2 at ~p of positions (targets), 1 at their hidden neighbours at x-offsets xs / y-offsets ys."
        n = int(np.prod(patch_size) * p)
        x_inds = np.random.randint(0, patch_size[1], n)
        y_inds = np.random.randint(0, patch_size[0], n)
        ma = np.zeros(patch_size)
        for i in xs:
            m = x_inds - i >= 0
            ma[y_inds[m], x_inds[m] - i] = 1
            m = x_inds + i < patch_size[1]
            ma[y_inds[m], x_inds[m] + i] = 1
        for i in ys:
            m = y_inds - i >= 0
            ma[y_inds[m] - i, x_inds[m]] = 1
            m = y_inds + i < patch_size[0]
            ma[y_inds[m] + i, x_inds[m]] = 1
        ma = ma.astype(np.uint8)
        ma[y_inds, x_inds] = 2  # targets override neighbour marks
        return ma

    def corrupt(t, mask):
        "Copy of t with every pixel where the 2D mask > 0 replaced by U[0, 1) noise."
        damaged = t.clone()
        damaged[:, :, mask > 0] = torch.rand(t.shape).cuda()[:, :, mask > 0]
        return damaged

    plt.figure()
    for e in range(ta.e, end_epoch):
        ta.e = e
        np.random.shuffle(inds)
        # per-sample losses of the previous epoch (all -2 on the first epoch)
        ta.lossdists.append(lossdist.numpy().copy())
        lossdist[...] = -1
        print(f"\r epoch {e}", end="")

        for b in range(batches_per_epoch):
            idxs = inds[b * batch_size:(b + 1) * batch_size]
            x1 = d.x1_all[idxs]

            ma = torch.from_numpy(sparse_3set_mask(xs=(1, 2)).astype(np.float64)).cuda()
            ma2 = torch.from_numpy(sparse_3set_mask(xs=(1, 2)).astype(np.float64)).cuda()

            ## stage 1: net predicts the input from its masked version
            y1p = d.net(corrupt(x1, ma))
            ## stage 2: net2 must see net's output WITH the second mask applied;
            ## otherwise the identity trivially minimises the second loss term
            y2p = d.net2(corrupt(y1p, ma2))

            dims = (1, 2, 3)  # all dims except batch
            n_batch = x1.shape[0]  # may be < batch_size on the last batch
            tm1 = (ma == 2).float().repeat(n_batch, 1, 1, 1)  # target masks, (B, 1, Y, X)
            tm2 = (ma2 == 2).float().repeat(n_batch, 1, 1, 1)
            loss_per_patch = (tm1 * torch.abs(y1p - x1)**2).sum(dims) / tm1.sum(dims)
            loss_per_patch += (tm2 * torch.abs(y2p - y1p)**2).sum(dims) / tm2.sum(dims)

            lossdist[idxs] = loss_per_patch.detach().cpu()
            loss = loss_per_patch.mean()
            ta.losses.append(float(loss))

            opt.zero_grad()
            opt2.zero_grad()
            loss.backward()
            opt.step()
            opt2.step()

        ## predict on the examples (both stages) and save a montage
        with torch.no_grad():
            example_yp = d.net(example_xs)
            example_yp2 = d.net2(example_yp)
            yp_fft = torch.fft((example_yp2 - example_yp2.mean())[..., None][..., [0, 0]], 2).norm(p=2, dim=-1)
            yp_fft = torch.from_numpy(np.fft.fftshift(yp_fft.cpu(), axes=(-1, -2))).cuda()
            # Montage rows, in order:
            #   inputs; last batch's first mask (scaled to [0, 1]);
            #   input fft; final predictions; prediction fft
            mask_row = ma.float().repeat(len(example_xs), 1, 1, 1) / 2
            rgb = torch.stack([example_xs, mask_row, xs_fft, example_yp2, yp_fft], 0).cpu().detach().numpy()
        arr = rgb.copy()  # dims = type, sample, channel, y, x
        rgb = normalize3(rgb, axs=(1, 2, 3, 4))
        rgb[[2, 4]] = normalize3(rgb[[2, 4]], pmin=0, pmax=99.0, axs=(1, 2, 3, 4))
        # drop the channel axis and tile into one 2D image
        rgb = collapse2(rgb[:, :, 0], 'tsyx', 'sy,tx')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if e % 10 == 0:
                io.imsave(d.savedir / f'epochs/rgb_{e:03d}.png', rgb)
            if e % 100 == 0:
                np.save(d.savedir / f'epochs_npy/arr_{e:03d}.npy', arr)

        ## plot the loss curve (x axis in epochs)
        epochs = np.arange(len(ta.losses)) / batches_per_epoch
        plt.clf()
        plt.plot(epochs, ta.losses)
        plt.yscale('log')
        plt.xlabel(f'1 epoch = {batches_per_epoch} batches')
        plt.savefig(d.savedir / 'loss.png', dpi=300)

        if e % 100 == 0:
            torch.save(d.net.state_dict(), d.savedir / f'models/net{e:03d}.pt')

    pklsave(ta.losses, d.savedir / 'losses.pkl')
    torch.save(d.net.state_dict(), d.savedir / f'models/net{ta.e:03d}.pt')
    return ta