def get_data(n_images, axes, shape):
    """Return a ``RawData`` of ``n_images`` random pairs where target = 3 * source + 5.

    Sources are drawn lazily from ``rng.uniform(size=shape)``; each pair is yielded with the
    given ``axes`` and no mask.
    """
    def _pairs():
        for _ in range(n_images):
            source = rng.uniform(size=shape)
            target = 3 * source + 5
            yield source, target, axes, None

    return RawData(_pairs, n_images, '')
def get_data(n_images, axes, shape):
    """Return RawData whose targets are means of random sources over random axes.

    Between 1 and ``len(axes) - 1`` distinct axes are picked at random (``red_axes``), and
    ``keepdims`` is chosen at random as well. Each target is ``np.mean`` of its source over
    those axes.

    :return: ``(raw_data, red_axes, keepdims)``
    """
    n_reduced = 1 + rng.choice(len(axes) - 1)
    red_axes = ''.join(rng.choice(tuple(axes), n_reduced, replace=False))
    keepdims = rng.choice((True, False))

    def _gen():
        for _ in range(n_images):
            source = rng.uniform(size=shape)
            reduce_dims = tuple(axes_dict(axes)[a] for a in red_axes)
            target = np.mean(source, axis=reduce_dims, keepdims=keepdims)
            yield source, target, axes, None

    raw_data = RawData(_gen, n_images, '')
    return raw_data, red_axes, keepdims
def gen_raw_data(n: int, **kwargs):
    """Generate a csbdeep ``RawData`` object from synthetic training images.

    The images are produced by :func:`dispim.neural.datagen.gen_training_data`. Every keyword
    argument is forwarded to it unchanged. The previous docstring listed ``use_noise``,
    ``use_psf``, ``use_subsampling`` and ``shape``; TODO confirm these against that function.

    :param n: number of images, passed as the first argument of ``gen_training_data``.
    :return: tuple ``(raw_data, psf_as, psf_bs)``.
        - ``raw_data`` pairs the degraded images (source) with the clean images (target),
          using axes ``'XYZC'``.
        - ``psf_as`` and ``psf_bs`` are passed through from ``gen_training_data``.
    """
    degraded, clean, psf_as, psf_bs = gen_training_data(n, **kwargs)
    raw_data = RawData.from_arrays(degraded, clean, axes='XYZC')
    return raw_data, psf_as, psf_bs
def create_patches(self, normalization=None):
    """Create and save training patches for every channel in ``self.train_channels``.

    Paths used for channel ``ch``:

    - Pairs are read from ``<out_dir>/train_data/raw/CH_<ch>/``, with ``low`` as the source
      directory and ``GT`` as the target directory.
    - Patches are saved to ``self.get_training_patch_path() / "CH_<ch>_training_patches.npz"``.

    Six randomly chosen patch pairs are plotted per channel.

    :param normalization: optional normalization forwarded to the module-level
        ``create_patches``. When ``None`` it is not passed at all, so the callee's default
        applies.
    """
    for ch in self.train_channels:
        channel_dir = pathlib.Path(self.out_dir) / "train_data" / "raw" / "CH_{}".format(ch)
        n_images = len(list((channel_dir / "GT").glob("*.tif")))
        print("-- Creating {} patches for channel: {}".format(
            n_images * self.n_patches_per_image, ch))
        raw_data = RawData.from_folder(
            basepath=channel_dir,
            source_dirs=["low"],
            target_dir="GT",
            axes=self.axes,
        )
        # Forward ``normalization`` only when given, so the callee keeps its own default.
        extra_kwargs = {} if normalization is None else {"normalization": normalization}
        # Bare-name lookup inside a method resolves to the module-level ``create_patches``
        # (presumably csbdeep's), not to this method.
        X, Y, XY_axes = create_patches(
            raw_data=raw_data,
            patch_size=self.patch_size,
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.get_training_patch_path() / "CH_{}_training_patches.npz".format(ch),
            verbose=False,
            **extra_kwargs,
        )
        plt.figure(figsize=(16, 4))
        rand_sel = numpy.random.randint(low=0, high=len(X), size=6)
        plot_some(X[rand_sel, 0], Y[rand_sel, 0], title_list=[range(6)], cmap="gray")
        plt.show()
    print("Done")
def _create(img_size, img_axes, patch_size, patch_axes):
    """Create patches from random images, reload them from ``save_file`` and check the round trip.

    Relies on ``rng``, ``n_images``, ``n_patches_per_image``, ``save_file`` and
    ``backend_channels_last`` from the enclosing scope, which is not visible here.
    """
    U, V = (rng.uniform(size=(n_images,) + img_size) for _ in range(2))
    X, Y, XYaxes = create_patches(
        raw_data=RawData.from_arrays(U, V, img_axes),
        patch_size=patch_size,
        patch_axes=patch_axes,
        n_patches_per_image=n_patches_per_image,
        save_file=save_file,
    )
    (_X, _Y), val_data, _XYaxes = load_training_data(save_file, verbose=True)
    assert val_data is None
    # The saved data must carry the channel axis where the Keras backend expects it.
    assert _XYaxes[-1 if backend_channels_last else 1] == 'C'
    _X, _Y = (move_image_axes(u, fr=_XYaxes, to=XYaxes) for u in (_X, _Y))
    assert np.allclose(X, _X, atol=1e-6)
    assert np.allclose(Y, _Y, atol=1e-6)
    assert set(XYaxes) == set(_XYaxes)
    # Element 1 of load_training_data's result is the validation data. Element 2 is the axes
    # string, which is never None, so checking it would verify nothing.
    assert load_training_data(save_file, validation_split=0.5)[1] is not None
    assert all(len(x) == 3 for x in load_training_data(save_file, n_images=3)[0])
def test_rawdata_from_folder(tmpdir):
    """Write random X/Y image stacks as TIFFs and check RawData.from_folder reads them back."""
    rng = np.random.RandomState(42)
    root = Path(str(tmpdir))
    n_images, img_size, img_axes = 3, (64, 64), 'YX'
    stack_shape = (n_images,) + img_size

    data = {}
    for key in ('X', 'Y'):
        data[key] = rng.uniform(size=stack_shape).astype(np.float32)

    for name, images in data.items():
        folder = root / name
        folder.mkdir(exist_ok=True)
        for idx, img in enumerate(images):
            imsave(str(folder / ('img_%02d.tif' % idx)), img)

    raw_data = RawData.from_folder(str(root), ['X'], 'Y', img_axes)
    assert raw_data.size == n_images
    for x, y, axes, mask in raw_data.generator():
        assert mask is None
        assert axes == img_axes
        # File order is not guaranteed, so only membership in the written stacks is checked.
        assert any(np.allclose(x, u) for u in data['X'])
        assert any(np.allclose(y, u) for u in data['Y'])
def generate_2D_patch_training_data(BaseDirectory, SaveNpzDirectory, SaveName, patch_size=(512, 512),
                                    n_patches_per_image=64, transforms=None):
    """Create 2D (YX) training patches from ``Original`` -> ``BinaryMask`` image pairs.

    :param BaseDirectory: folder containing the ``Original`` (source) and ``BinaryMask``
        (target) sub-folders.
    :param SaveNpzDirectory: directory the ``.npz`` file is written to. It may be a str or a
        path-like object, with or without a trailing separator.
    :param SaveName: file name of the ``.npz`` file inside ``SaveNpzDirectory``.
    :param patch_size: patch size in YX order.
    :param n_patches_per_image: number of patches sampled per image pair.
    :param transforms: optional list of transforms forwarded to ``create_patches``.
    :return: ``(X, Y, XY_axes)`` as returned by ``create_patches``.
    """
    import os

    raw_data = RawData.from_folder(
        basepath=BaseDirectory,
        source_dirs=['Original'],
        target_dir='BinaryMask',
        axes='YX',
    )
    # os.path.join avoids gluing directory and file name together when the directory lacks a
    # trailing separator, and also accepts pathlib.Path arguments.
    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=patch_size,
        n_patches_per_image=n_patches_per_image,
        transforms=transforms,
        save_file=os.path.join(SaveNpzDirectory, SaveName),
    )
    return X, Y, XY_axes
def data_generation(data_path, axes, patch_size, data_name):
    """Generate training data for a CARE network.

    A `RawData` object defines how to get the pairs of low/high SNR stacks and the semantics
    of each axis. The folders "noisy" and "clean" hold corresponding low- and high-SNR stacks
    as TIFF images with identical file names.

    Parameters
    ----------
    data_path : str
        Path of the input data containing the 'noisy' and 'clean' folders.
    axes : str
        Semantic order of the image axes (e.g. 'ZYX').
    patch_size : tuple
        Size of the patches cropped from the images, in the order given by `axes`.
    data_name : str
        Base name of the .npz file. The file is saved as 'data/<data_name>.npz'.

    Returns
    -------
    X, Y : numpy.ndarray
        Source and target patches.
    XY_axes : str
        Axes of `X` and `Y`.
    """
    raw_data = RawData.from_folder(
        basepath=data_path,
        source_dirs=['noisy'],
        target_dir='clean',
        axes=axes,
    )
    # Use a patch size that is a power of two along XYZT, or at least divisible by 8.
    # By convention `X` is the model input and `Y` the model output.
    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=patch_size,
        n_patches_per_image=100,
        save_file='data/' + data_name + '.npz',
    )
    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes of X,Y =", XY_axes)
    return X, Y, XY_axes
# os.listdir returns entries in arbitrary order. Images and labels are paired by list
# index below, so both listings must be sorted for the pairs to line up.
imgList = sorted(os.listdir(train_dir))
labelList = sorted(os.listdir(label_dir))
imgArray = [imread(os.path.join(train_dir, name)) for name in tqdm(imgList, 'Reading img')]
labelArray = [imread(os.path.join(label_dir, name)) for name in tqdm(labelList, 'Reading label')]
print(imgArray[0].shape)
print(labelArray[0].shape)

patch_file = '/mnt/AE3205C73205958D/Data/3dliver_local/pc_adult/2d_slices/imagesXY/image_full/mydata_128x128patch.npz'

raw_data = RawData.from_arrays(imgArray, labelArray, axes='YX')
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(128, 128, 1),
    patch_axes='YXC',
    n_patches_per_image=25,
    save_file=patch_file,
)
(X, Y), (X_val, Y_val), axes = load_training_data(patch_file, validation_split=0.1, verbose=True)
def Train(self):
    """Prepare segmentation masks, optionally build an NPZ patch file, and train models.

    Steps, each gated by an instance flag:

    1. Ensure ``BinaryMask/`` and ``RealMask/`` exist under ``self.BaseDir``.
       - If ``RealMask/`` has no TIFFs, create them by applying ``label`` to the binary masks.
       - If ``BinaryMask/`` has no TIFFs, create them by thresholding the instance masks
         (``> 0``).
    2. ``self.GenerateNPZ``: create ZYX patches (``Raw/`` -> ``BinaryMask/``) and save them to
       ``self.BaseDir + self.NPZfilename + '.npz'``.
    3. ``self.TrainUNET``: train a CARE model named ``'UNET' + self.model_name`` on that file.
    4. ``self.TrainSTAR``: train a ``StarDist3D`` model on ``Raw/`` images and ``RealMask/``
       labels. They are loaded fully into memory, or via ``self.DataSequencer`` when
       ``self.CroppedLoad`` is truthy.

    :raises ValueError: if ``self.backbone`` is neither ``'resnet'`` nor ``'unet'``, or if the
        number of raw images and instance masks differ when loading fully.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'
    Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
    ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
    ValRealMask = sorted(glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

    print('Instance segmentation masks:', len(RealMask))
    if len(RealMask) == 0:
        print('Making labels')
        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        for fname in Mask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = label(image)
            imwrite((self.BaseDir + '/' + RealName + Name + '.tif'), Binaryimage.astype('uint16'))
        # Re-scan after writing. The earlier listing was empty, and the StarDist branch below
        # pairs ``Raw`` with ``RealMask`` by index.
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

    Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
    print('Semantic segmentation masks:', len(Mask))
    if len(Mask) == 0:
        print('Generating Binary images')
        RealfilesMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*tif'))
        for fname in RealfilesMask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = image > 0
            imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'), Binaryimage.astype('uint16'))

    if self.GenerateNPZ:
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        X, Y, XY_axes = create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.BaseDir + self.NPZfilename + '.npz',
        )

    weight_files = ('weights_now.h5', 'weights_last.h5', 'weights_best.h5')

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        load_path = self.BaseDir + self.NPZfilename + '.npz'
        (X, Y), (X_val, Y_val), axes = load_training_data(load_path, validation_split=0.1, verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_n_depth=self.depth,
                        train_epochs=self.epochs,
                        train_batch_size=self.batch_size,
                        unet_n_first=self.startfilter,
                        train_loss='mse',
                        unet_kern_size=self.kern_size,
                        train_learning_rate=self.learning_rate,
                        train_reduce_lr={
                            'patience': 5,
                            'factor': 0.5
                        })
        print(config)
        model = CARE(config, name='UNET' + self.model_name, basedir=self.model_dir)
        unet_dir = self.model_dir + 'UNET' + self.model_name + '/'
        if self.copy_model_dir is not None:
            copy_now = self.copy_model_dir + 'UNET' + self.copy_model_name + '/' + 'weights_now.h5'
            # Seed from the copy model only when this model has no checkpoint of its own yet.
            if os.path.exists(copy_now) and not os.path.exists(unet_dir + 'weights_now.h5'):
                print('Loading copy model')
                model.load_weights(copy_now)
        # Checkpoints are loaded in order, so a later file overrides an earlier one
        # (best > last > now).
        for weights in weight_files:
            if os.path.exists(unet_dir + weights):
                print('Loading checkpoint model')
                model.load_weights(unet_dir + weights)

        history = model.train(X, Y, validation_data=(X_val, Y_val))
        print(sorted(list(history.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        self.axis_norm = (0, 1, 2)
        if not self.CroppedLoad:
            assert len(Raw) > 1, "not enough training data"
            if len(Raw) != len(RealMask):
                raise ValueError('Found %d raw images but %d instance masks' % (len(Raw), len(RealMask)))
            print(len(Raw))
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))

            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))
            self.Y = [label(DownsampleData(y, self.DownsampleFactor)) for y in tqdm(Y_train)]
            self.X = [
                normalize(DownsampleData(x, self.DownsampleFactor), 1, 99.8, axis=self.axis_norm)
                for x in tqdm(X_train)
            ]
            # Hold out ~15 % of the images (at least one) for validation.
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]
            self.X_val, self.Y_val = [self.X[i] for i in ind_val], [self.Y[i] for i in ind_val]
            self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [self.Y[i] for i in ind_train]
            print('number of images: %3d' % len(self.X))
            print('- training: %3d' % len(self.X_trn))
            print('- validation: %3d' % len(self.X_val))
        else:
            self.X_trn = self.DataSequencer(Raw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_trn = self.DataSequencer(RealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.X_val = self.DataSequencer(ValRaw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_val = self.DataSequencer(ValRealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.train_sample_cache = False

        print(Config3D.__doc__)
        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
        # The training patch size follows the image axis order (ZYX), as in create_patches above.
        patch_size = (self.PatchZ, self.PatchY, self.PatchX)
        if self.backbone == 'resnet':
            conf = Config3D(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                resnet_n_blocks=self.depth,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                resnet_kernel_size=(self.kern_size, self.kern_size, self.kern_size),
                train_patch_size=patch_size,
                train_batch_size=self.batch_size,
                resnet_n_filter_base=self.startfilter,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1)
        elif self.backbone == 'unet':
            conf = Config3D(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                unet_n_depth=self.depth,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                unet_kernel_size=(self.kern_size, self.kern_size, self.kern_size),
                train_patch_size=patch_size,
                train_batch_size=self.batch_size,
                unet_n_filter_base=self.startfilter,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1,
                train_sample_cache=False)
        else:
            raise ValueError("Unsupported backbone %r; expected 'resnet' or 'unet'" % (self.backbone,))

        print(conf)
        Starmodel = StarDist3D(conf, name=self.model_name, basedir=self.model_dir)
        star_dir = self.model_dir + self.model_name + '/'
        print(Starmodel._axes_tile_overlap('ZYX'), os.path.exists(star_dir + 'weights_now.h5'))
        if self.copy_model_dir is not None:
            copy_dir = self.copy_model_dir + self.copy_model_name + '/'
            # Seed each checkpoint kind from the copy model only if this model lacks it.
            for weights in weight_files:
                if os.path.exists(copy_dir + weights) and not os.path.exists(star_dir + weights):
                    print('Loading copy model')
                    Starmodel.load_weights(copy_dir + weights)
        # A later file overrides an earlier one (best > last > now).
        for weights in weight_files:
            if os.path.exists(star_dir + weights):
                print('Loading checkpoint model')
                Starmodel.load_weights(star_dir + weights)

        historyStar = Starmodel.train(self.X_trn,
                                      self.Y_trn,
                                      validation_data=(self.X_val, self.Y_val),
                                      epochs=self.epochs)
        print(sorted(list(historyStar.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae', 'dist_relevant_mse', 'val_dist_relevant_mse'
        ])
# Debug output used while developing (disabled):
# print('image size =', x.shape)
# print('Z subsample factor =', subsample)
# plt.figure(figsize=(389/100, 389/100)); plt.imshow(x, cmap='gray'); plt.show()
#
# Alternative input source (disabled):
# raw_data = RawData.from_folder(basepath='data', source_dirs=['simulator_data'],
#                                target_dir='simulator_data', axes='CYX')

# The same image serves as source and target. The transform below degrades the source.
raw_data = RawData.from_arrays(X=[x], Y=[x], axes='CYX')

anisotropic_transform = anisotropic_distortions(
    subsample=1,
    psf=np.ones((3)) / 9,  # placeholder 1-D kernel along Y; use the actual PSF here
    psf_axes='Y',
    # poisson_noise=True,
    # gauss_sigma=0.1,
)

# A single patch per image that spans the first three dimensions of ``x``.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=tuple(x.shape[d] for d in range(3)),
    n_patches_per_image=1,
    transforms=[anisotropic_transform],
)
def get_dataset(pd_scribbles, n_patches_per_image_train=30, n_patches_per_image_val=8, patch_size=(128, 128),
                p_label=0.6, val_perc=0.3, verbose=True, border=False):
    """Build training and validation patch sets from images and their scribble annotations.

    For each row ``i`` of ``pd_scribbles``:

    - The image is read from ``input_dir[i] + input_file[i]`` (npz key ``'image'``) and the
      scribbles from ``input_dir[i] + scribble_file[i]`` (npz key ``'scribble'``).
    - A random rectangle, ``val_perc`` of each image side, becomes the validation field of
      view. It is re-drawn until the scribbles summed inside it exceed 10. The remainder of
      the image is the training field of view.
    - In each field of view, about ``1 - p_label`` of the patches are sampled without a
      filter. The rest are sampled with ``mask_filter_index`` over all scribble channels.

    :param pd_scribbles: table indexed as ``pd_scribbles[column][row]`` with columns
        ``input_dir``, ``input_file`` and ``scribble_file``. Presumably a pandas DataFrame;
        confirm with callers.
    :param n_patches_per_image_train: patches per image drawn from the training region.
    :param n_patches_per_image_val: patches per image drawn from the validation region.
    :param patch_size: spatial patch size.
    :param p_label: target fraction of patches drawn with the scribble-based filter.
    :param val_perc: side fraction of the validation rectangle.
    :param verbose: plot the training/validation fields of view for each image.
    :param border: if True, return Y with all scribble channels. Otherwise each group of 3
        channels (foreground, border, background) is collapsed into 2 channels
        (foreground, border + background).
    :return: ``(X_train, Y_train, X_val, Y_val)``
    :raises ValueError: if ``pd_scribbles`` is empty.
    """
    # Collect per-image patch arrays and concatenate once at the end; concatenating inside the
    # loop copies everything accumulated so far on every iteration.
    X_train_parts, Y_train_parts = [], []
    X_val_parts, Y_val_parts = [], []
    for i in range(len(pd_scribbles)):
        input_dir = pd_scribbles['input_dir'][i]
        # Read the image; the context manager closes the npz file handle.
        with np.load(input_dir + pd_scribbles['input_file'][i]) as npz_read:
            image = npz_read['image']
        # Read the scribbles.
        with np.load(input_dir + pd_scribbles['scribble_file'][i]) as npz_read:
            scribble = npz_read['scribble']

        raw_image_in = image + 0  # copy; normalize(image, pmin=pmin, pmax=pmax, clip=False) was an alternative

        # Sample the validation rectangle until it covers enough scribbles.
        patch_val_size = [int(image.shape[0] * val_perc), int(image.shape[1] * val_perc)]
        scribble_sum = np.sum(scribble[...], axis=-1)
        all_back = True
        while all_back:
            val_mask = np.zeros([raw_image_in.shape[0], raw_image_in.shape[1]])
            ix_x = np.random.randint(0, raw_image_in.shape[0] - patch_val_size[0])
            # The second offset must be bounded by the second dimension (was shape[0]).
            ix_y = np.random.randint(0, raw_image_in.shape[1] - patch_val_size[1])
            val_mask[ix_x:ix_x + patch_val_size[0], ix_y:ix_y + patch_val_size[1]] = 1
            if np.sum(val_mask * scribble_sum) > 10:
                all_back = False

        # Generate patches
        raw_data = RawData.from_arrays(raw_image_in[np.newaxis, ...], scribble[np.newaxis, ...])

        # RGB overlay for plotting: image in green, first scribble channel in red,
        # remaining scribble channels in blue.
        if verbose:
            aux = np.zeros([raw_image_in.shape[0], raw_image_in.shape[1], 3])
            if len(raw_image_in.shape) > 2:
                aux[..., 1] = np.sum(raw_image_in, axis=-1) * 0.8
            else:
                aux[..., 1] = raw_image_in * 0.8
            aux[..., 0] = scribble[..., 0]
            aux[..., 2] = np.sum(scribble[..., 1:], axis=2)

        for group in ['val', 'train']:
            if group == 'val':
                fov_mask = np.array(val_mask)
                n_patches_per_image = n_patches_per_image_val
                if verbose:
                    plt.figure(figsize=(10, 5))
                    plt.subplot(1, 2, 1)
                    plt.title('Validation FOV')
                    plt.imshow(fov_mask[..., np.newaxis] * aux)
            else:
                fov_mask = 1 - np.array(val_mask)
                n_patches_per_image = n_patches_per_image_train
                if verbose:
                    plt.subplot(1, 2, 2)
                    plt.title('Train FOV')
                    plt.imshow(fov_mask[..., np.newaxis] * aux)
                    plt.show()

            # Unfiltered patches first, then top up with patches that contain scribbles.
            X_aux, Y_aux, axes = generate_patches_syxc(raw_data, patch_size,
                                                       int(n_patches_per_image * (1 - p_label)),
                                                       normalization=None, patch_filter=None,
                                                       fov_mask=fov_mask)
            n_patches_add = int(n_patches_per_image - X_aux.shape[0])
            if n_patches_add > 0:
                X_labeled_aux, Y_labeled_aux, axes = generate_patches_syxc(
                    raw_data, patch_size, n_patches_add, normalization=None,
                    mask_filter_index=np.arange(scribble.shape[-1]), fov_mask=fov_mask)
                if X_labeled_aux is not None:
                    X_aux = np.concatenate([X_aux, X_labeled_aux], axis=0)
                    Y_aux = np.concatenate([Y_aux, Y_labeled_aux], axis=0)

            if group == 'val':
                X_val_parts.append(X_aux)
                Y_val_parts.append(Y_aux)
            else:
                X_train_parts.append(X_aux)
                Y_train_parts.append(Y_aux)

    if not X_train_parts:
        raise ValueError('pd_scribbles is empty; no patches were generated')
    X_train = np.concatenate(X_train_parts, axis=0)
    Y_train = np.concatenate(Y_train_parts, axis=0)
    X_val = np.concatenate(X_val_parts, axis=0)
    Y_val = np.concatenate(Y_val_parts, axis=0)

    print(Y_train.shape, Y_val.shape)
    if border:
        return X_train, Y_train, X_val, Y_val

    # Y channels come in groups of 3 per class: foreground, border, background.
    # Group j therefore starts at channel 3*j (previously out_channels*j, which overlaps
    # neighbouring groups whenever out_channels is 2 or more than 3).
    out_channels = int(Y_train.shape[-1] / 3)
    Y_train_aux = np.zeros([Y_train.shape[0], Y_train.shape[1], Y_train.shape[2], out_channels * 2])
    Y_val_aux = np.zeros([Y_val.shape[0], Y_val.shape[1], Y_val.shape[2], out_channels * 2])
    for j in range(out_channels):
        first = 3 * j
        Y_train_aux[..., 2 * j] = Y_train[..., first]  # foreground
        Y_train_aux[..., 2 * j + 1] = Y_train[..., first + 1] + Y_train[..., first + 2]  # border + background are background
        Y_val_aux[..., 2 * j] = Y_val[..., first]  # foreground
        Y_val_aux[..., 2 * j + 1] = Y_val[..., first + 1] + Y_val[..., first + 2]  # border + background are background
    return X_train, Y_train_aux, X_val, Y_val_aux
"""Command-line tool: create csbdeep training patches from paired image folders."""
from argparse import ArgumentParser

from csbdeep.data import RawData, create_patches

parser = ArgumentParser(description="Create csbdeep training patches from paired image folders.")
parser.add_argument("input_basepath", help="base folder containing the source and target sub-folders")
parser.add_argument("input_x", help="source (input) sub-folder name")
parser.add_argument("input_y", help="target (ground-truth) sub-folder name")
parser.add_argument("output_file", help="path of the .npz file to write")
parser.add_argument("--patch_size_xy", type=int, default=32, help="patch size along X and Y")
parser.add_argument("--patch_size_z", type=int, default=16, help="patch size along every other axis (e.g. Z)")
parser.add_argument("--n_patches_per_image", type=int, default=750)
parser.add_argument("--axes", default="ZYX", help="axes of the input images, without C")
args = parser.parse_args()

data = RawData.from_folder(basepath=args.input_basepath,
                           source_dirs=[args.input_x],
                           target_dir=args.input_y,
                           axes=args.axes)
ps_xy = args.patch_size_xy
ps_z = args.patch_size_z
patch_axes = args.axes + "C"
# One entry per patch axis: X/Y -> ps_xy, the appended channel axis -> 1, any other axis -> ps_z.
# For the default "ZYX" this gives (ps_z, ps_xy, ps_xy, 1). 2D axes such as "YX" also work.
patch_size = tuple(1 if a == "C" else ps_xy if a in "XY" else ps_z for a in patch_axes)
X, Y, axes = create_patches(data,
                            patch_size=patch_size,
                            n_patches_per_image=args.n_patches_per_image,
                            save_file=args.output_file,
                            patch_axes=patch_axes)
def original() -> None:
    """Generate multi-channel 2D training patches for isotropic reconstruction.

    Steps:

    - Load ``data/img006_noconv.npy`` and treat every time point as one source/target pair.
    - Degrade the source with ``data.anisotropic_distortions`` (subsample factor
      ``subsample``; PSF from ``buildkernel``).
    - Extract patches of size ``(2, 1, 128, 128)`` and drop the singleton Z dimension.
    - Save ``X``, ``Y`` and their axes to ``isonet_psf_1/my_training_data.npz``.
    - Write example panels as PNGs to the same directory.

    The ``if False:`` blocks are disabled experiments kept for reference.
    """
    x = np.load('data/img006_noconv.npy')
    x = np.moveaxis(x, 1, 2)  ## initial axes are TZCYX; swapping axes 1 and 2 gives TCZYX, so each time point x[i] is CZYX
    axes = 'CZYX'
    mypath = Path('isonet_psf_1')
    mypath.mkdir(exist_ok=True)
    subsample = 5.0  # alternative value: 10.2
    print('image size =', x.shape)
    print('image axes =', axes)
    print('subsample factor =', subsample)
    plt.switch_backend('agg')  # non-interactive backend: figures are only saved to files

    # Disabled: preview xy and xz slices of the first time point.
    if False:
        plt.figure(figsize=(15, 15))
        plot_some(np.moveaxis(x[0], 1, -1)[[5, -5]], title_list=[['xy slice', 'xy slice']], pmin=2, pmax=99.8)
        plt.savefig(mypath / 'datagen_1.png')
        plt.figure(figsize=(15, 15))
        plot_some(np.moveaxis(np.moveaxis(x[0], 1, -1)[:, [50, -50]], 1, 0), title_list=[['xz slice', 'xz slice']], pmin=2, pmax=99.8, aspect=subsample)
        plt.savefig(mypath / 'datagen_2.png')

    def gimmeit_gen():
        ## iterate over time dimension: each time point is used as both source and target
        for i in range(x.shape[0]):
            yield x[i], x[i], axes, None

    raw_data = RawData(gimmeit_gen, x.shape[0], "this is great!")

    ## initial idea (disabled): Gaussian kernel
    if False:
        def buildkernel():
            kern = np.exp(-(np.arange(10)**2 / 2))
            kern /= kern.sum()
            kern = kern.reshape([1, 1, -1, 1])
            kern = np.stack([kern, kern], axis=1)
            return kern
        psf_kern = buildkernel()

    ## use Martins theoretical psf (disabled)
    if False:
        psf_aniso = imread('data/psf_aniso_NA_0.8.tif')
        psf_channels = np.stack([ psf_aniso, ] * 2, axis=1)

    def buildkernel():
        # Box kernel: 6 taps of weight 1/6 (indices 7..12) inside a length-20 vector.
        kernel = np.zeros(20)
        kernel[7:13] = 1 / 6
        ## reshape into ZYX of shape (1, 1, 20). long axis is X.
        kernel = kernel.reshape([1, 1, -1])
        ## repeat same kernel for both channels -> CZYX of shape (2, 1, 1, 20)
        kernel = np.stack([kernel, kernel], axis=0)
        return kernel
    psf = buildkernel()
    print(psf.shape)

    ## use theoretical psf (disabled; note it loads data/measured_psfs.npy)
    if False:
        psf_channels = np.load('data/measured_psfs.npy')
        psf = rotate(psf_channels, 90, axes=(1, 3))

    iso_transform = data.anisotropic_distortions(
        subsample=subsample,
        psf=psf,
    )
    X, Y, XY_axes = data.create_patches(
        raw_data=raw_data,
        patch_size=(2, 1, 128, 128),
        n_patches_per_image=256,
        transforms=[iso_transform],
    )
    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes of X,Y =", XY_axes)
    # remove dummy z dim to obtain multi-channel 2D patches
    # (assumes Z is at index 2 of the patch array, e.g. axes SCZYX; TODO confirm against the printed axes)
    X = X[:, :, 0, ...]
    Y = Y[:, :, 0, ...]
    XY_axes = XY_axes.replace('Z', '')
    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes of X,Y =", XY_axes)
    np.savez(mypath / 'my_training_data.npz', X=X, Y=Y, axes=XY_axes)
    # Save two panels of 8 example source/target patch pairs each.
    for i in range(2):
        plt.figure(figsize=(16, 4))
        sl = slice(8 * i, 8 * (i + 1))
        plot_some(np.moveaxis(X[sl], 1, -1), np.moveaxis(Y[sl], 1, -1), title_list=[np.arange(sl.start, sl.stop)])
        plt.savefig(mypath / 'datagen_panel_{}.png'.format(i))
y = imread(os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"))
print(f"image size: {x.shape}")
plt.figure(figsize=(16, 10))
plot_some(
    np.stack([x, y]),
    title_list=[['low snr', 'high snr']],
)
plt.show()

# Move the selected pair into <base_dir>/test so it is excluded from patch generation below.
# makedirs(exist_ok=True) avoids the check-then-create race of exists() + mkdir().
test_dir = os.path.join(args.base_dir, "test")
os.makedirs(test_dir, exist_ok=True)
shutil.move(os.path.join(args.base_dir, "low_snr", f"img_{img_number}.tif"),
            os.path.join(test_dir, "low_snr.tif"))
shutil.move(os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"),
            os.path.join(test_dir, "high_snr.tif"))

# Read the pairs, passing in the axis semantics
raw_data = RawData.from_folder(basepath=args.base_dir,
                               source_dirs=['low_snr'],
                               target_dir='high_snr',
                               axes='YX')

# From the stacks, generate 2D patches, and save to the output_filename
create_patches(raw_data=raw_data,
               patch_size=(128, 128),
               n_patches_per_image=512,
               save_file=os.path.join(args.base_dir, args.output_filename))
def training(self):
    """Train a CSBDeep CARE model on 2D TIFF pairs using Keras on-the-fly augmentation.

    Inputs are ``data/training/{original,ground_truth}/*.tif`` and
    ``data/validation/{original,ground_truth}/*.tif``. The method proceeds as follows:

    - Crop one 64x64 patch per image pair into ``data/training.npz`` and
      ``data/validation.npz``.
    - Train for ``epochs`` rounds. Each round rebuilds the model with a decayed learning rate
      and reloads the best weights so far.
    - Keep the weights with the lowest validation MAE in ``models/CSBDeep/model.h5``.
    - Finally reload those weights and export the model with ``export_TF``.
    """
    import glob
    import os

    from matplotlib import pyplot as plt
    from skimage import io

    # Load files and plot examples
    basepath = 'data/'
    training_original_dir = 'training/original/'
    training_ground_truth_dir = 'training/ground_truth/'
    validation_original_dir = 'validation/original/'
    validation_ground_truth_dir = 'validation/ground_truth/'

    training_original_files = sorted(glob.glob(basepath + training_original_dir + '*.tif'))
    training_original_file = io.imread(training_original_files[0])
    training_ground_truth_files = sorted(glob.glob(basepath + training_ground_truth_dir + '*.tif'))
    training_ground_truth_file = io.imread(training_ground_truth_files[0])
    print("Training dataset's number of files and dimensions: ", len(training_original_files),
          training_original_file.shape, len(training_ground_truth_files), training_ground_truth_file.shape)

    validation_original_files = sorted(glob.glob(basepath + validation_original_dir + '*.tif'))
    validation_original_file = io.imread(validation_original_files[0])
    validation_ground_truth_files = sorted(glob.glob(basepath + validation_ground_truth_dir + '*.tif'))
    validation_ground_truth_file = io.imread(validation_ground_truth_files[0])
    print("Validation dataset's number of files and dimensions: ", len(validation_original_files),
          validation_original_file.shape, len(validation_ground_truth_files), validation_ground_truth_file.shape)

    # Compare full shapes; len() of an image only compares its first dimension.
    if training_original_file.shape != validation_original_file.shape:
        print('Training and validation images should be of the same dimensions!')

    plt.figure(figsize=(16, 4))
    examples = (training_original_file, training_ground_truth_file,
                validation_original_file, validation_ground_truth_file)
    for pos, img in enumerate(examples, start=1):
        plt.subplot(1, 4, pos)
        plt.imshow(img)

    # Prepare NN inputs from pairs of 32-bit TIFF images with intensities in range 0..1.
    # Run this only once for each new dataset.
    from csbdeep.data import RawData, create_patches, no_background_patches
    training_data = RawData.from_folder(
        basepath=basepath,
        source_dirs=[training_original_dir],
        target_dir=training_ground_truth_dir,
        axes='YX',
    )
    validation_data = RawData.from_folder(
        basepath=basepath,
        source_dirs=[validation_original_dir],
        target_dir=validation_ground_truth_dir,
        axes='YX',
    )
    # Augmented patches are produced later by ImageDataGenerator, so only one crop per image
    # is taken here. The original note says this crop should equal the image size;
    # TODO confirm the images are size1 x size1.
    size1 = 64
    X, Y, XY_axes = create_patches(
        raw_data=training_data,
        patch_size=(size1, size1),
        patch_filter=no_background_patches(0),
        n_patches_per_image=1,
        save_file=basepath + 'training.npz',
    )
    X_val, Y_val, XY_axes = create_patches(
        raw_data=validation_data,
        patch_size=(size1, size1),
        patch_filter=no_background_patches(0),
        n_patches_per_image=1,
        save_file=basepath + 'validation.npz',
    )

    # Load training and validation data into memory
    from csbdeep.io import load_training_data
    (X, Y), _, axes = load_training_data(basepath + 'training.npz', verbose=False)
    (X_val, Y_val), _, axes = load_training_data(basepath + 'validation.npz', verbose=False)

    from csbdeep.utils import axes_dict
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    batch = len(X)  # Choose the batch size according to the available memory.
    seed = 1
    from keras.preprocessing.image import ImageDataGenerator
    data_gen_args = dict(
        samplewise_center=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        fill_mode='reflect',
        rotation_range=30,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        # width_shift_range=0.2,
        # height_shift_range=0.2,
    )

    def _paired_flow(images, masks):
        # Images and masks use identical augmentation arguments and the same seed, so that
        # both flows are driven by the same seed. NOTE(review): this assumes that keeps the
        # random transforms in lockstep; confirm for the installed Keras version.
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)
        image_datagen.fit(images, augment=True, seed=seed)
        mask_datagen.fit(masks, augment=True, seed=seed)
        image_generator = image_datagen.flow(images, batch_size=batch, seed=seed)
        mask_generator = mask_datagen.flow(masks, batch_size=batch, seed=seed)
        return zip(image_generator, mask_generator)

    generator = _paired_flow(X, Y)
    generator_val = _paired_flow(X_val, Y_val)

    # Plot augmented examples
    x, y = next(generator)
    x_val, y_val = next(generator_val)
    plt.figure(figsize=(16, 4))
    for pos, img in enumerate((x, y, x_val, y_val), start=1):
        plt.subplot(1, 4, pos)
        plt.imshow(img[0, :, :, 0])

    blocks = 2
    channels = 16
    learning_rate = 0.0004
    learning_rate_decay_factor = 0.95
    epoch_size_multiplicator = 20  # effectively sets how often the learning rate is decreased
    epochs = 10
    model_path = 'models/CSBDeep/model.h5'
    if os.path.isfile(model_path):
        print('Your model will be overwritten in the next cell')
    kernel_size = 3

    from csbdeep.models import Config, CARE
    from keras import backend as K

    best_mae = 1
    steps_per_epoch = len(X) * epoch_size_multiplicator
    validation_steps = len(X_val) * epoch_size_multiplicator
    # `model` is a local of this method. The former `if 'model' in globals(): del model`
    # raised UnboundLocalError whenever a global `model` existed, so it was removed.
    if os.path.isfile(model_path):
        os.remove(model_path)
    # keras_model.save does not create parent directories, and only 'models/' was created
    # before, so create 'models/CSBDeep/' here.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    for i in range(epochs):
        print('Epoch:', i + 1)
        # The rate is decayed before every round, including the first.
        learning_rate = learning_rate * learning_rate_decay_factor
        config = Config(axes, n_channel_in, n_channel_out,
                        unet_kern_size=kernel_size,
                        train_learning_rate=learning_rate,
                        unet_n_depth=blocks,
                        unet_n_first=channels)
        model = CARE(config, '.', basedir='models')
        # Continue from the best weights so far. They only exist once a round has improved
        # on best_mae.
        if i > 0 and os.path.isfile(model_path):
            model.keras_model.load_weights(model_path)
        model.prepare_for_training()
        history = model.keras_model.fit_generator(generator,
                                                  validation_data=generator_val,
                                                  validation_steps=validation_steps,
                                                  epochs=1,
                                                  verbose=0,
                                                  shuffle=True,
                                                  steps_per_epoch=steps_per_epoch)
        if history.history['val_mae'][0] < best_mae:
            best_mae = history.history['val_mae'][0]
            model.keras_model.save(model_path)
        print(f'Validation MAE:{best_mae:.3E}')
        del model
        K.clear_session()

    # Rebuild the model with the best weights and export it.
    config = Config(axes, n_channel_in, n_channel_out,
                    unet_kern_size=kernel_size,
                    train_learning_rate=learning_rate,
                    unet_n_depth=blocks,
                    unet_n_first=channels)
    model = CARE(config, '.', basedir='models')
    model.keras_model.load_weights(model_path)
    model.export_TF()