Exemplo n.º 1
0
    def create_patches(self, normalization=None):
        """Create (and preview) training patches for every training channel.

        For each channel ``ch`` in ``self.train_channels`` the image pairs in
        ``<out_dir>/train_data/raw/CH_<ch>/low`` (source) and ``.../GT``
        (target) are turned into patches by the module-level
        ``create_patches`` function. The patches are saved to
        ``<training patch path>/CH_<ch>_training_patches.npz``, and six
        randomly selected source/target patches are plotted.

        :param normalization: optional normalization forwarded to
            ``create_patches``; when ``None`` the keyword is not passed at all,
            so ``create_patches`` keeps its own default.
        """
        for ch in self.train_channels:
            channel_dir = (pathlib.Path(self.out_dir) / "train_data" / "raw" /
                           "CH_{}".format(ch))
            n_images = len(list((channel_dir / "GT").glob("*.tif")))
            print("-- Creating {} patches for channel: {}".format(
                n_images * self.n_patches_per_image, ch))
            raw_data = RawData.from_folder(
                basepath=channel_dir,
                source_dirs=["low"],
                target_dir="GT",
                axes=self.axes,
            )

            # Only forward ``normalization`` when the caller supplied one.
            extra_kwargs = ({} if normalization is None else
                            {"normalization": normalization})
            # Inside this method the bare name resolves to the module-level
            # ``create_patches`` function, not to this method.
            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=self.patch_size,
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.get_training_patch_path() /
                "CH_{}_training_patches.npz".format(ch),
                verbose=False,
                **extra_kwargs,
            )

            plt.figure(figsize=(16, 4))

            # Preview six random patch pairs (duplicates are possible).
            rand_sel = numpy.random.randint(low=0, high=len(X), size=6)
            plot_some(X[rand_sel, 0],
                      Y[rand_sel, 0],
                      title_list=[range(6)],
                      cmap="gray")

            plt.show()

        print("Done")
Exemplo n.º 2
0
 def _create(img_size, img_axes, patch_size, patch_axes):
     """Create patches from generated data and check count, content and axes.

     ``get_data``, ``n_images`` and ``n_patches_per_image`` come from the
     enclosing scope.
     """
     raw = get_data(n_images, img_axes, img_size)
     inputs, targets, out_axes = create_patches(
         raw_data=raw,
         patch_size=patch_size,
         patch_axes=patch_axes,
         n_patches_per_image=n_patches_per_image,
     )
     expected_count = n_images * n_patches_per_image
     assert len(inputs) == expected_count
     # Source and target patches must match; presumably get_data yields
     # identical source/target images — TODO confirm.
     assert np.allclose(inputs, targets, atol=1e-6)
     if patch_axes is None:
         return
     # Returned axes: sample and channel first, then the non-channel patch axes.
     assert out_axes == 'SC' + patch_axes.replace('C', '')
Exemplo n.º 3
0
 def _create(img_size,img_axes,patch_size,patch_axes):
     """Round-trip patches through ``save_file`` and check the reloaded data.

     Random source/target stacks of shape ``(n_images,) + img_size`` are
     turned into patches and saved. The file is then reloaded with
     ``load_training_data`` and compared with the in-memory result.
     ``rng``, ``n_images``, ``n_patches_per_image``, ``save_file`` and
     ``backend_channels_last`` come from the enclosing scope.
     """
     U,V = (rng.uniform(size=(n_images,)+img_size) for _ in range(2))
     X,Y,XYaxes = create_patches (
         raw_data            = RawData.from_arrays(U,V,img_axes),
         patch_size          = patch_size,
         patch_axes          = patch_axes,
         n_patches_per_image = n_patches_per_image,
         save_file           = save_file
     )
     (_X,_Y), val_data, _XYaxes = load_training_data(save_file,verbose=True)
     # Without a validation split no validation data is returned.
     assert val_data is None
     # Channel axis position presumably depends on the backend data layout.
     assert _XYaxes[-1 if backend_channels_last else 1] == 'C'
     _X,_Y = (move_image_axes(u,fr=_XYaxes,to=XYaxes) for u in (_X,_Y))
     assert np.allclose(X,_X,atol=1e-6)
     assert np.allclose(Y,_Y,atol=1e-6)
     assert set(XYaxes) == set(_XYaxes)
     # Index 1 is the validation data (index 2 is the axes string, which is
     # never None), so this checks that a split actually yields validation data.
     assert load_training_data(save_file,validation_split=0.5)[1] is not None
     assert all(len(x)==3 for x in load_training_data(save_file,n_images=3)[0])
Exemplo n.º 4
0
def generate_2D_patch_training_data(BaseDirectory,
                                    SaveNpzDirectory,
                                    SaveName,
                                    patch_size=(512, 512),
                                    n_patches_per_image=64,
                                    transforms=None):
    """Create 2D (YX) training patches from an Original/BinaryMask folder pair.

    Parameters
    ----------
    BaseDirectory : str
        Folder containing the ``Original`` (source) and ``BinaryMask``
        (target) sub-folders.
    SaveNpzDirectory : str
        Directory into which the ``.npz`` file is written. A trailing path
        separator is optional.
    SaveName : str
        File name of the ``.npz`` file.
    patch_size : tuple
        Patch shape in YX order.
    n_patches_per_image : int
        Number of patches extracted from each image.
    transforms : list or None
        Optional transforms forwarded to ``create_patches``.

    Returns
    -------
    tuple
        ``(X, Y, XY_axes)`` as returned by ``create_patches``.
    """
    import os

    raw_data = RawData.from_folder(
        basepath=BaseDirectory,
        source_dirs=['Original'],
        target_dir='BinaryMask',
        axes='YX',
    )

    # os.path.join inserts the separator that plain concatenation dropped
    # when SaveNpzDirectory had no trailing slash.
    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=patch_size,
        n_patches_per_image=n_patches_per_image,
        transforms=transforms,
        save_file=os.path.join(SaveNpzDirectory, SaveName),
    )
    return X, Y, XY_axes
Exemplo n.º 5
0
def split_chunks(
    data: List[np.ndarray], n_chunks: int, psf_as: List[np.ndarray],
    psf_bs: List[np.ndarray]
) -> Tuple[np.ndarray, np.ndarray, str, np.ndarray, np.ndarray]:
    """
    Extract chunks (patches) from RawData with ``create_patches``.

    Chunks of shape (64, 64, 64, 2) are extracted with no patch filter and
    without shuffling. No ``normalization`` argument is passed, so
    ``create_patches`` applies its own default.

    :param data: The RawData to extract the chunks from (passed as the first,
        ``raw_data`` argument of ``create_patches``)
    :param n_chunks: The number of chunks to extract per image
    :param psf_as: The list of PSFs corresponding to the first channel of data
    :param psf_bs: The list of PSFs corresponding to the second channel of data
    :return: (X chunks, Y chunks, axes, psf_as, psf_bs), where each PSF is
        repeated ``n_chunks`` times
    """
    X, Y, axes = create_patches(data, (64, 64, 64, 2),
                                n_patches_per_image=n_chunks,
                                patch_filter=None,
                                shuffle=False)
    # np.repeat repeats each PSF n_chunks times consecutively. With
    # shuffle=False the chunks are presumably grouped per image in the same
    # order, which keeps PSFs aligned with their chunks — TODO confirm.
    repeated_psf_as = np.repeat(psf_as, n_chunks, axis=0)
    repeated_psf_bs = np.repeat(psf_bs, n_chunks, axis=0)
    return X, Y, axes, repeated_psf_as, repeated_psf_bs
Exemplo n.º 6
0
def data_generation(data_path, axes, patch_size, data_name,
                    n_patches_per_image=100, save_dir='data'):
    """Generate training data for a CARE network.

    A ``RawData`` object defines how to get the pairs of low/high-SNR stacks
    and the semantics of each axis. ``data_path`` must contain two folders,
    "noisy" and "clean", holding the corresponding low- and high-SNR stacks
    as TIFF images with identical file names.

    Parameters
    ----------
    data_path : str
        Path of the input data containing the 'noisy' and 'clean' folders.
    axes : str
        Semantic order of the image axes (e.g. 'ZYX').
    patch_size : tuple
        Size of the patches cropped from the images.
    data_name : str
        Name (without extension) of the .npz file containing the image pairs.
    n_patches_per_image : int, optional
        Number of patches extracted per image (default 100).
    save_dir : str, optional
        Directory in which ``<data_name>.npz`` is written (default 'data').

    Returns
    -------
    tuple
        ``(X, Y, XY_axes)`` as returned by ``create_patches``.
    """
    import os

    raw_data = RawData.from_folder(
        basepath=data_path,
        source_dirs=['noisy'],
        target_dir='clean',
        axes=axes,
    )

    # Use a patch size that is a power of two along XYZT, or at least
    # divisible by 8. By convention, `X` is the network input and `Y` the
    # target output.
    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=patch_size,
        n_patches_per_image=n_patches_per_image,
        save_file=os.path.join(save_dir, data_name + '.npz'),
    )

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)
    return X, Y, XY_axes
Exemplo n.º 7
0
for image in tqdm(imgList, 'Reading img'):
    imgArray.append(imread(os.path.join(train_dir, image)))

labelArray = []
for label in tqdm(labelList, 'Reading label'):
    labelArray.append(imread(os.path.join(label_dir, label)))

print(imgArray[0].shape)
print(labelArray[0].shape)

raw_data = RawData.from_arrays(imgArray, labelArray, axes='YX')

# Single source of truth for the patch file: it is written by create_patches
# and read back by load_training_data.
PATCH_NPZ_FILE = '/mnt/AE3205C73205958D/Data/3dliver_local/pc_adult/2d_slices/imagesXY/image_full/mydata_128x128patch.npz'

# Images are YX; patch_axes='YXC' adds a singleton channel to each patch.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(128, 128, 1),
    patch_axes='YXC',
    n_patches_per_image=25,
    save_file=PATCH_NPZ_FILE,
)

(X, Y), (X_val, Y_val), axes = load_training_data(
    PATCH_NPZ_FILE,
    validation_split=0.1,
    verbose=True)

# Inspect one training patch (requires at least n + 1 training patches).
n = 10
print(X[n].shape)
fig = plt.figure()
pl1_x = fig.add_subplot(2, 5, 1)
pl1_x.imshow(X[n][..., 0])
pl1_y = fig.add_subplot(2, 5, 6)
Exemplo n.º 8
0
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    '''
    Train a 2D CARE model and save it in ``model_path``.

    Parameters
    ----------
    Training_source : str
        Path to the folder with the noisy (input) ``.tif`` images.
    Training_target : str
        Path to the folder with the ground-truth (target) ``.tif`` images.
    model_name : str
        Name of the model. An existing ``model_path/model_name`` folder is
        deleted before training.
    model_path : str
        Folder in which the model and ``rawdata.npz`` are stored.
    Visual_validation_after_training : bool
        If True, call ``Predict_a_image`` after training.
    number_of_epochs : int
        Number of training epochs.
    patch_size : int
        Edge length of the square training patches.
    number_of_patches : int
        Number of patches extracted from each image.
    Use_Default_Advanced_Parameters : bool
        If True, override ``batch_size`` (64) and ``percentage_validation``
        (10), and derive ``number_of_steps`` from the number of patches.
    number_of_steps : int
        Training steps per epoch (ignored when the defaults are used).
    batch_size : int
        Training batch size.
    percentage_validation : int
        Percentage of patches held out for validation.

    Returns
    -------
    None
    '''
    import math
    import time

    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"
    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100

    # If a model with the same name already exists, delete it.
    model_dir = model_path + '/' + model_name
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)

    # The images are loaded here only to report their shapes.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData holds the (low, GT) image pairs so that CARE compares
    # corresponding images.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    # np.savez appends the ".npz" extension, so the file is read back from here.
    rawdata1 = training_path + ".npz"
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load the training data, split into training and validation sets.
    (X, Y), (X_val, Y_val), axes = load_training_data(
        rawdata1, validation_split=percentage, verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Derive the steps per epoch from the data: one pass over all training
    # patches (ceil, so no partial batch is dropped and none is added twice).
    if Use_Default_Advanced_Parameters:
        number_of_steps = max(1, math.ceil(X.shape[0] / batch_size))

    print(number_of_steps)

    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)

    # Compile the CARE model for network training.
    model_training = CARE(config, model_name, basedir=model_path)

    start = time.time()

    # Start training.
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))
    # Measure only the training itself, not the optional prediction below.
    dt = time.time() - start

    print("Training, done.")
    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Display the time elapsed for training.
    minutes, sec = divmod(dt, 60)
    hour, minutes = divmod(minutes, 60)
    print("Time elapsed:", hour, "hour(s)", minutes, "min(s)", round(sec),
          "sec(s)")

    Show_loss_function(history, model_path)
Exemplo n.º 9
0
    basepath=
    '/run/user/1000/gvfs/smb-share:server=isiserver.curie.net,share=u934/equipe_bellaiche/l_sancere/Training_Data_Sets/Training_CARE_restoration/SpinwideFRAP4_Training_CARE_40x_bin2_reduced',
    source_dirs=['Low'],
    target_dir='GT',
    axes='ZYX',
    pattern='*.TIF')

# In[3]:

patch_size = (16, 64, 64)  # for bin1 it is 16 128 128 and for bin2 it is 16 64 64
n_patches_per_image = 64  # at least 64?

X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=patch_size,
    n_patches_per_image=n_patches_per_image,
    save_file=
    '/run/media/sancere/DATA1/Lucas_NextonCreated_npz/Training_CARE_restoration_SpinwideFRAP4_Bin2_Reduced.npz',
)

# In[4]:

# Record the patch parameters used for the .npz file. The context manager
# closes the file even if the write fails.
with open(
        "/run/media/sancere/DATA1/Lucas_NextonCreated_npz/Parameters_Npz/ConfigNPZ_Training_CARE_restoration_SpinwideFRAP4_Bin2.txt",
        "w") as ConfigNPZ:
    ConfigNPZ.write("patch_size = {} \n n_patches_per_image = {}".format(
        patch_size, n_patches_per_image))

# In[5]:
Exemplo n.º 10
0
    def Train(self):
        """Prepare masks, optionally build patches, and train UNET/StarDist3D.

        Steps:

        1. If ``<BaseDir>/RealMask/`` holds no ``.tif`` files, label every
           mask in ``<BaseDir>/BinaryMask/`` with ``label`` and write the
           result to ``RealMask/``.
        2. If ``<BaseDir>/BinaryMask/`` holds no ``.tif`` files, threshold
           every instance mask (``> 0``) and write it to ``BinaryMask/``.
        3. ``self.GenerateNPZ``: extract ``Raw/`` -> ``BinaryMask/`` patches
           into ``<BaseDir><NPZfilename>.npz``.
        4. ``self.TrainUNET``: train a CARE model named ``'UNET' + model_name``
           on that file, resuming from copy/checkpoint weights if present.
        5. ``self.TrainSTAR``: train a ``StarDist3D`` model with a ``'resnet'``
           or ``'unet'`` backbone, resuming from copy/checkpoint weights.

        Raises:
            ValueError: if ``TrainSTAR`` is set and ``self.backbone`` is
                neither ``'resnet'`` nor ``'unet'``.
        """
        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        # Each existing checkpoint is loaded in this order, so a later file
        # replaces the weights loaded from an earlier one.
        checkpoint_names = ('weights_now.h5', 'weights_last.h5',
                            'weights_best.h5')

        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:
            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
            for fname in Mask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                # `label` is presumably connected-component labelling
                # (skimage) — turns the binary mask into instances.
                Labelimage = label(image)
                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Labelimage.astype('uint16'))
            # Re-scan so the freshly written label images are used below.
            # Otherwise RealMask stays empty while Raw is not.
            RealMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')
            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))
            for fname in RealfilesMask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                Binaryimage = image > 0
                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:
            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val, Y_val), axes = load_training_data(
                load_path, validation_split=0.1, verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            unet_dir = self.model_dir + 'UNET' + self.model_name + '/'
            if self.copy_model_dir is not None:
                # Start from the copy model only if this model has no
                # 'weights_now' checkpoint of its own yet.
                copy_weights = (self.copy_model_dir + 'UNET' +
                                self.copy_model_name + '/' + 'weights_now.h5')
                if (os.path.exists(copy_weights) and
                        not os.path.exists(unet_dir + 'weights_now.h5')):
                    print('Loading copy model')
                    model.load_weights(copy_weights)

            for weights_name in checkpoint_names:
                if os.path.exists(unet_dir + weights_name):
                    print('Loading checkpoint model')
                    model.load_weights(unet_dir + weights_name)

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            # Fail before any expensive loading if the backbone is unsupported.
            if self.backbone not in ('resnet', 'unet'):
                raise ValueError(
                    "backbone must be 'resnet' or 'unet', got {!r}".format(
                        self.backbone))
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if not self.CroppedLoad:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                # Hold out ~15% (at least one image) for validation.
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val = [self.X[i] for i in ind_val]
                self.Y_val = [self.Y[i] for i in ind_val]
                self.X_trn = [self.X[i] for i in ind_train]
                self.Y_trn = [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))
            else:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            kernel_size = (self.kern_size, self.kern_size, self.kern_size)
            common_kwargs = dict(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                # ZYX order, matching the (PatchZ, PatchY, PatchX) patches
                # used for the UNET data above.
                train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                train_batch_size=self.batch_size,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1,
            )
            if self.backbone == 'resnet':
                conf = Config3D(resnet_n_blocks=self.depth,
                                resnet_kernel_size=kernel_size,
                                resnet_n_filter_base=self.startfilter,
                                **common_kwargs)
            else:  # 'unet' (validated above)
                conf = Config3D(unet_n_depth=self.depth,
                                unet_kernel_size=kernel_size,
                                unet_n_filter_base=self.startfilter,
                                train_sample_cache=False,
                                **common_kwargs)

            print(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            star_dir = self.model_dir + self.model_name + '/'
            print(Starmodel._axes_tile_overlap('ZYX'),
                  os.path.exists(star_dir + 'weights_now.h5'))

            if self.copy_model_dir is not None:
                # Use a copy-model checkpoint only where this model lacks the
                # corresponding checkpoint of its own.
                copy_dir = self.copy_model_dir + self.copy_model_name + '/'
                for weights_name in checkpoint_names:
                    if (os.path.exists(copy_dir + weights_name) and
                            not os.path.exists(star_dir + weights_name)):
                        print('Loading copy model')
                        Starmodel.load_weights(copy_dir + weights_name)

            for weights_name in checkpoint_names:
                if os.path.exists(star_dir + weights_name):
                    print('Loading checkpoint model')
                    Starmodel.load_weights(star_dir + weights_name)

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
#     axes        = 'CYX',
# )

raw_data = RawData.from_arrays(X=[x], Y=[x], axes='CYX')

anisotropic_transform = anisotropic_distortions(
    subsample=1,
    # Placeholder 3-tap box PSF along Y, normalized to unit sum (the former
    # /9 looked like a leftover from a 3x3 kernel and summed to 1/3).
    psf=np.ones(3) / 3,  # use the actual PSF here
    psf_axes='Y',
    # poisson_noise = True,
    # gauss_sigma = 0.1
)

# One patch covering the whole (C, Y, X) image.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(x.shape[0], x.shape[1], x.shape[2]),
    n_patches_per_image=1,
    transforms=[anisotropic_transform],
)
# z = axes_dict(XY_axes)['Z']
# X = np.take(X,0,axis=z)
# Y = np.take(Y,0,axis=z)
# XY_axes = XY_axes.replace('Z','')
# assert X.shape == Y.shape
# print("shape of X,Y =", X.shape)
# print("axes  of X,Y =", XY_axes)
# plt.figure(figsize=(389/100, 389/100))
for i, patch in enumerate(X):
    # Borderless figure sized from the patch (200 pixels per inch).
    fig = plt.figure(frameon=False)
    fig.set_size_inches(patch.shape[1] / 200, patch.shape[0] / 200)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
Exemplo n.º 12
0
from argparse import ArgumentParser
from csbdeep.data import RawData, create_patches

parser = ArgumentParser(
    description="Extract training patches from paired image folders "
    "with csbdeep's create_patches.")
parser.add_argument("input_basepath",
                    help="Base folder containing the source and target folders.")
parser.add_argument("input_x",
                    help="Source-image folder, relative to input_basepath.")
parser.add_argument("input_y",
                    help="Target-image folder, relative to input_basepath.")
parser.add_argument("output_file", help="Path of the .npz file to write.")
# type=int so command-line values are parsed as numbers, like the defaults.
parser.add_argument("--patch_size_xy", type=int, default=32,
                    help="Patch size along every non-Z axis.")
parser.add_argument("--patch_size_z", type=int, default=16,
                    help="Patch size along the Z axis.")
parser.add_argument("--n_patches_per_image", type=int, default=750)
parser.add_argument("--axes", default="ZYX",
                    help="Axes of the raw images (a channel axis is appended).")
args = parser.parse_args()

data = RawData.from_folder(basepath=args.input_basepath,
                           source_dirs=[args.input_x],
                           target_dir=args.input_y,
                           axes=args.axes)

ps_xy = args.patch_size_xy
ps_z = args.patch_size_z

# One size per image axis (Z -> ps_z, any other -> ps_xy) plus a singleton
# channel. For the default "ZYX" this is (ps_z, ps_xy, ps_xy, 1).
patch_size = tuple(ps_z if ax.upper() == 'Z' else ps_xy
                   for ax in args.axes) + (1,)

X, Y, axes = create_patches(data,
                            patch_size=patch_size,
                            n_patches_per_image=args.n_patches_per_image,
                            save_file=args.output_file,
                            patch_axes=args.axes + "C")
Exemplo n.º 13
0
def original():
    """Generate multi-channel 2-D training patches from ``data/img006_noconv.npy``.

    Loads the stack, treats every time point as one source/target pair and
    degrades the source with ``data.anisotropic_distortions``. The PSF is a
    length-20 kernel along X with a 6-wide box of 1/6. Subsampling uses
    factor ``subsample``. The function extracts 256 patches of size
    (2, 1, 128, 128) (CZYX) per time point, drops the singleton Z axis, and
    saves X, Y and the axes string to ``isonet_psf_1/my_training_data.npz``.
    Preview panels are written as PNGs into ``isonet_psf_1/`` using the
    non-interactive 'agg' backend.

    ``data`` is presumably the ``csbdeep.data`` module imported elsewhere --
    confirm.

    NOTE(review): everything from ``y = imread(...)`` inside the final
    plotting loop to the end of the function references names that are not
    defined here (``args``, ``img_number``). It also re-plots ``x`` as if it
    were a 2-D "low snr" image and moves files into ``<base_dir>/test``. This
    looks spliced in from a different script. It raises NameError unless
    those names exist as module globals. Confirm and separate.
    """
    x = np.load('data/img006_noconv.npy')
    x = np.moveaxis(x, 1, 2)
    ## initial axes are TZCYX; moving axis 1 to position 2 gives TCZYX,
    ## so each time point x[i] has the axes below
    axes = 'CZYX'

    mypath = Path('isonet_psf_1')
    mypath.mkdir(exist_ok=True)

    subsample = 5.0  #10.2
    print('image size       =', x.shape)
    print('image axes       =', axes)
    print('subsample factor =', subsample)

    # non-interactive backend: figures are only written out via savefig
    plt.switch_backend('agg')

    # disabled: preview xy and xz slices of the first time point
    if False:
        plt.figure(figsize=(15, 15))
        plot_some(np.moveaxis(x[0], 1, -1)[[5, -5]],
                  title_list=[['xy slice', 'xy slice']],
                  pmin=2,
                  pmax=99.8)
        plt.savefig(mypath / 'datagen_1.png')

        plt.figure(figsize=(15, 15))
        plot_some(np.moveaxis(np.moveaxis(x[0], 1, -1)[:, [50, -50]], 1, 0),
                  title_list=[['xz slice', 'xz slice']],
                  pmin=2,
                  pmax=99.8,
                  aspect=subsample)
        plt.savefig(mypath / 'datagen_2.png')

    def gimmeit_gen():
        ## iterate over time dimension: one (source, target, axes, mask) tuple
        ## per time point, with identical source and target and no mask
        for i in range(x.shape[0]):
            yield x[i], x[i], axes, None

    raw_data = RawData(gimmeit_gen, x.shape[0], "this is great!")

    ## initial idea (disabled): Gaussian kernel
    if False:

        def buildkernel():
            kern = np.exp(-(np.arange(10)**2 / 2))
            kern /= kern.sum()
            kern = kern.reshape([1, 1, -1, 1])
            kern = np.stack([kern, kern], axis=1)
            return kern

        psf_kern = buildkernel()

    ## use Martins theoretical psf (disabled; psf_channels is never used)
    if False:
        psf_aniso = imread('data/psf_aniso_NA_0.8.tif')
        psf_channels = np.stack([
            psf_aniso,
        ] * 2, axis=1)

    def buildkernel():
        # 1-D box of width 6 (values 1/6, summing to 1) inside 20 samples
        kernel = np.zeros(20)
        kernel[7:13] = 1 / 6
        ## reshape into ZYX with shape (1, 1, 20); the long axis is X
        kernel = kernel.reshape([1, 1, -1])
        ## repeat the same kernel for both channels -> CZYX, shape (2, 1, 1, 20)
        kernel = np.stack([kernel, kernel], axis=0)
        return kernel

    psf = buildkernel()
    print(psf.shape)

    ## disabled: use measured psfs from data/measured_psfs.npy, rotated by
    ## 90 degrees in the plane of axes 1 and 3
    if False:
        psf_channels = np.load('data/measured_psfs.npy')
        psf = rotate(psf_channels, 90, axes=(1, 3))

    iso_transform = data.anisotropic_distortions(
        subsample=subsample,
        psf=psf,
    )

    X, Y, XY_axes = data.create_patches(
        raw_data=raw_data,
        patch_size=(2, 1, 128, 128),
        n_patches_per_image=256,
        transforms=[iso_transform],
    )

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)

    # remove dummy z dim to obtain multi-channel 2D patches
    # (assumes XY_axes is SCZYX, so index 2 is the size-1 Z axis)
    X = X[:, :, 0, ...]
    Y = Y[:, :, 0, ...]
    XY_axes = XY_axes.replace('Z', '')

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)

    np.savez(mypath / 'my_training_data.npz', X=X, Y=Y, axes=XY_axes)

    # two preview panels of 8 source/target patches each (channels moved last)
    for i in range(2):
        plt.figure(figsize=(16, 4))
        sl = slice(8 * i, 8 * (i + 1))
        plot_some(np.moveaxis(X[sl], 1, -1),
                  np.moveaxis(Y[sl], 1, -1),
                  title_list=[np.arange(sl.start, sl.stop)])
        plt.savefig(mypath / 'datagen_panel_{}.png'.format(i))
        # NOTE(review): from here on the code appears to come from another
        # script -- `args` and `img_number` are not defined in this function.
        y = imread(
            os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"))
        print(f"image size: {x.shape}")
        plt.figure(figsize=(16, 10))
        plot_some(
            np.stack([x, y]),
            title_list=[['low snr', 'high snr']],
        )
        plt.show()

    # Hold out image pair `img_number` by moving it into <base_dir>/test.
    if not os.path.exists(os.path.join(args.base_dir, "test")):
        os.mkdir(os.path.join(args.base_dir, "test"))
    shutil.move(
        os.path.join(args.base_dir, "low_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "low_snr.tif"))
    shutil.move(
        os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "high_snr.tif"))

    # Read the pairs, passing in the axis semantics
    raw_data = RawData.from_folder(basepath=args.base_dir,
                                   source_dirs=['low_snr'],
                                   target_dir='high_snr',
                                   axes='YX')

    # From the stacks, generate 2D patches, and save to the output_filename
    create_patches(raw_data=raw_data,
                   patch_size=(128, 128),
                   n_patches_per_image=512,
                   save_file=os.path.join(args.base_dir, args.output_filename))
Exemplo n.º 15
0
    def training(self):
        """Train a CSBDeep CARE model on paired 2-D TIFF images and export it.

        Expects source/ground-truth image pairs under ``data/training/`` and
        ``data/validation/``, each with ``original/`` and ``ground_truth/``
        sub-folders. The method runs these steps in order:

        1. Shows the first image of each of the four folders.
        2. Extracts one 64x64 patch per image into ``data/training.npz`` and
           ``data/validation.npz``, then loads them back.
        3. Builds Keras augmentation generators (rotation, shear, zoom, flips).
        4. Trains for ``epochs`` rounds of one Keras epoch each. The learning
           rate is decayed every round, and the weights with the lowest
           validation MAE are kept in ``models/CSBDeep/model.h5``.
        5. Reloads those best weights and calls ``CARE.export_TF``.

        Side effects: writes both ``.npz`` files, deletes and rewrites
        ``models/CSBDeep/model.h5``, lets CARE write under ``models/``, and
        opens matplotlib figures. ``self`` is not used.
        """
        import glob
        import os

        from matplotlib import pyplot as plt
        from skimage import io

        def _show_row(images):
            # One 16x4 figure with the given images side by side in 1x4 subplots.
            plt.figure(figsize=(16, 4))
            for column, image in enumerate(images, start=1):
                plt.subplot(1, 4, column)
                plt.imshow(image)

        # --- Load file lists and plot one example of each set ---
        basepath = 'data/'
        training_original_dir = 'training/original/'
        training_ground_truth_dir = 'training/ground_truth/'
        validation_original_dir = 'validation/original/'
        validation_ground_truth_dir = 'validation/ground_truth/'

        training_original_files = sorted(
            glob.glob(basepath + training_original_dir + '*.tif'))
        training_original_file = io.imread(training_original_files[0])
        training_ground_truth_files = sorted(
            glob.glob(basepath + training_ground_truth_dir + '*.tif'))
        training_ground_truth_file = io.imread(training_ground_truth_files[0])
        print("Training dataset's number of files and dimensions: ",
              len(training_original_files), training_original_file.shape,
              len(training_ground_truth_files),
              training_ground_truth_file.shape)

        validation_original_files = sorted(
            glob.glob(basepath + validation_original_dir + '*.tif'))
        validation_original_file = io.imread(validation_original_files[0])
        validation_ground_truth_files = sorted(
            glob.glob(basepath + validation_ground_truth_dir + '*.tif'))
        validation_ground_truth_file = io.imread(
            validation_ground_truth_files[0])
        print("Validation dataset's number of files and dimensions: ",
              len(validation_original_files), validation_original_file.shape,
              len(validation_ground_truth_files),
              validation_ground_truth_file.shape)

        # Compare full shapes, as the message says. The previous len() check
        # only compared the first axis. This is a warning, not an error.
        if training_original_file.shape != validation_original_file.shape:
            print(
                'Training and validation images should be of the same dimensions!'
            )

        _show_row((training_original_file, training_ground_truth_file,
                   validation_original_file, validation_ground_truth_file))

        # --- Extract patches (run only once for each new dataset) ---
        # Inputs are pairs of 32-bit TIFF files with intensities in 0..1.
        from csbdeep.data import RawData, create_patches, no_background_patches
        training_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[training_original_dir],
            target_dir=training_ground_truth_dir,
            axes='YX',
        )
        validation_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[validation_original_dir],
            target_dir=validation_ground_truth_dir,
            axes='YX',
        )

        # Patches are augmented later, so only one patch per image is taken.
        # NOTE(review): the author intended the patch size to equal the image
        # size, but it is fixed at 64. This assumes 64x64 inputs; confirm.
        size1 = 64
        create_patches(
            raw_data=training_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'training.npz',
        )
        create_patches(
            raw_data=validation_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'validation.npz',
        )

        # --- Load training and validation data into memory ---
        from csbdeep.io import load_training_data
        from csbdeep.utils import axes_dict
        (X, Y), _, axes = load_training_data(basepath + 'training.npz',
                                             verbose=False)
        (X_val, Y_val), _, axes = load_training_data(
            basepath + 'validation.npz', verbose=False)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        # --- On-the-fly augmentation ---
        # One batch holds the whole set; reduce this if it does not fit into memory.
        batch = len(X)
        seed = 1
        from keras.preprocessing.image import ImageDataGenerator
        data_gen_args = dict(samplewise_center=False,
                             samplewise_std_normalization=False,
                             zca_whitening=False,
                             fill_mode='reflect',
                             rotation_range=30,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             vertical_flip=True)

        def _paired_flow(images, targets):
            # The same arguments and seed for both generators keep the random
            # transforms of inputs and targets in lockstep.
            image_datagen = ImageDataGenerator(**data_gen_args)
            mask_datagen = ImageDataGenerator(**data_gen_args)
            image_datagen.fit(images, augment=True, seed=seed)
            mask_datagen.fit(targets, augment=True, seed=seed)
            return zip(
                image_datagen.flow(images, batch_size=batch, seed=seed),
                mask_datagen.flow(targets, batch_size=batch, seed=seed),
            )

        generator = _paired_flow(X, Y)
        generator_val = _paired_flow(X_val, Y_val)

        # Plot one augmented example from each generator.
        x, y = next(generator)
        x_val, y_val = next(generator_val)
        _show_row((x[0, :, :, 0], y[0, :, :, 0], x_val[0, :, :, 0],
                   y_val[0, :, :, 0]))

        # --- Training ---
        blocks = 2
        channels = 16
        learning_rate = 0.0004
        learning_rate_decay_factor = 0.95
        epoch_size_multiplicator = 20  # basically, how often the learning rate is decreased
        epochs = 10
        kernel_size = 3
        model_path = 'models/CSBDeep/model.h5'
        if os.path.isfile(model_path):
            print('Your model will be overwritten in the next cell')

        from csbdeep.models import Config, CARE
        from keras import backend as K

        def _make_model(lr):
            # Fresh CARE U-Net (name '.', basedir 'models') with the shared settings.
            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_kern_size=kernel_size,
                            train_learning_rate=lr,
                            unet_n_depth=blocks,
                            unet_n_first=channels)
            return CARE(config, '.', basedir='models')

        # Start at inf rather than 1 so the first round always saves a
        # checkpoint. Otherwise round 2's load_weights and the final export
        # fail whenever the validation MAE is >= 1.
        best_mae = float('inf')
        steps_per_epoch = len(X) * epoch_size_multiplicator
        validation_steps = len(X_val) * epoch_size_multiplicator
        if os.path.isfile(model_path):
            os.remove(model_path)
        for i in range(epochs):
            print('Epoch:', i + 1)
            # The decay is applied before training, so round 1 already uses
            # 0.0004 * 0.95.
            learning_rate = learning_rate * learning_rate_decay_factor
            model = _make_model(learning_rate)
            if i > 0:
                # Resume from the best weights saved so far.
                model.keras_model.load_weights(model_path)
            model.prepare_for_training()
            history = model.keras_model.fit_generator(
                generator,
                validation_data=generator_val,
                validation_steps=validation_steps,
                epochs=1,
                verbose=0,
                shuffle=True,
                steps_per_epoch=steps_per_epoch)
            val_mae = history.history['val_mae'][0]
            if val_mae < best_mae:
                best_mae = val_mae
                # Create the checkpoint's own directory (models/CSBDeep/),
                # not just models/, before saving into it.
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                model.keras_model.save(model_path)
                print(f'Validation MAE:{best_mae:.3E}')
            del model
            K.clear_session()

        # --- Export the best weights ---
        model = _make_model(learning_rate)
        model.keras_model.load_weights(model_path)
        model.export_TF()