Exemplo n.º 1
0
 def get_data(n_images, axes, shape):
     """Build a ``RawData`` source of random images with a linear target.

     Each generated item is ``(x, y, axes, None)`` where ``x`` is drawn with
     ``rng.uniform(size=shape)`` and ``y = 5 + 3*x``; no mask is supplied.
     The images are produced lazily, one per iteration, ``n_images`` in total.
     """
     def _linear_pairs():
         for _ in range(n_images):
             source = rng.uniform(size=shape)
             target = 5 + 3*source
             yield source, target, axes, None

     # Positional fields: generator factory, number of images, description.
     return RawData(_linear_pairs, n_images, '')
Exemplo n.º 2
0
    def get_data(n_images, axes, shape):
        """Build a ``RawData`` source whose target is a mean-reduction of the input.

        A random non-empty subset of ``axes`` (at most ``len(axes) - 1`` of
        them) and a random ``keepdims`` flag are chosen once per call; every
        generated target is ``np.mean`` of the input over those axes.

        Returns a tuple ``(raw_data, red_axes, keepdims)`` so the caller knows
        which reduction was applied.
        """
        # Keep the three draws in this order so the random stream is unchanged.
        red_n = 1 + rng.choice(len(axes) - 1)
        picked = rng.choice(tuple(axes), red_n, replace=False)
        red_axes = ''.join(picked)
        keepdims = rng.choice((True, False))

        def _mean_pairs():
            # Resolve axis letters to positions lazily, when iteration starts.
            positions = axes_dict(axes)
            reduce_over = tuple(positions[a] for a in red_axes)
            for _ in range(n_images):
                source = rng.uniform(size=shape)
                target = np.mean(source, axis=reduce_over, keepdims=keepdims)
                yield source, target, axes, None

        raw_data = RawData(_mean_pairs, n_images, '')
        return raw_data, red_axes, keepdims
Exemplo n.º 3
0
def gen_raw_data(n: int, **kwargs):
    """
    Wrap the output of :func:`gen_training_data
    <dispim.neural.datagen.gen_training_data>` in a csbdeep ``RawData`` object.

    :param n: Passed as the first argument to ``gen_training_data``
        (presumably the number of images to generate -- verify there).
    :param kwargs: Forwarded unchanged to ``gen_training_data``. The previous
        docstring listed ``use_noise``, ``use_psf``, ``use_subsampling`` and
        ``shape``; presumably those are accepted there -- verify.
    :return: A tuple ``(raw_data, psf_as, psf_bs)`` where ``raw_data`` pairs
        the degraded images (source) with the original images (target) using
        axes ``'XYZC'``, and ``psf_as`` / ``psf_bs`` are returned exactly as
        produced by ``gen_training_data``.
    """
    degraded, clean, psf_as, psf_bs = gen_training_data(n, **kwargs)
    raw_data = RawData.from_arrays(degraded, clean, axes='XYZC')
    return raw_data, psf_as, psf_bs
Exemplo n.º 4
0
    def create_patches(self, normalization=None):
        """Extract and save training patches for every channel in ``self.train_channels``.

        For each channel the low/GT image pairs are read from
        ``<out_dir>/train_data/raw/CH_<ch>/{low,GT}``, patches are generated
        with the module-level ``create_patches`` function and written to
        ``CH_<ch>_training_patches.npz`` inside ``self.get_training_patch_path()``.
        Six randomly selected patch pairs are then plotted.

        :param normalization: Forwarded to the patch generator only when it is
            not ``None``; otherwise the generator's own default is used.
        :return: ``None``.
        """
        for ch in self.train_channels:
            channel_dir = (pathlib.Path(self.out_dir) / "train_data" / "raw" /
                           "CH_{}".format(ch))
            n_images = len(list((channel_dir / "GT").glob("*.tif")))
            print("-- Creating {} patches for channel: {}".format(
                n_images * self.n_patches_per_image, ch))

            raw_data = RawData.from_folder(
                basepath=channel_dir,
                source_dirs=["low"],
                target_dir="GT",
                axes=self.axes,
            )

            patch_kwargs = dict(
                raw_data=raw_data,
                patch_size=self.patch_size,
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.get_training_patch_path() /
                "CH_{}_training_patches.npz".format(ch),
                verbose=False,
            )
            # Only override the generator's default normalization on request.
            if normalization is not None:
                patch_kwargs["normalization"] = normalization
            # Inside the method body this name resolves to the module-level
            # function, not to this method.
            X, Y, XY_axes = create_patches(**patch_kwargs)

            plt.figure(figsize=(16, 4))
            rand_sel = numpy.random.randint(low=0, high=len(X), size=6)
            plot_some(X[rand_sel, 0],
                      Y[rand_sel, 0],
                      title_list=[range(6)],
                      cmap="gray")
            plt.show()

        print("Done")
Exemplo n.º 5
0
 def _create(img_size,img_axes,patch_size,patch_axes):
     """Round-trip check for ``create_patches`` / ``load_training_data``.

     Generates two random image stacks, extracts patches (saved to the
     enclosing scope's ``save_file``), reloads them and asserts that the
     reloaded data matches what ``create_patches`` returned, that the channel
     axis sits where the backend expects it, and that the ``validation_split``
     and ``n_images`` options of ``load_training_data`` take effect.
     """
     stack_shape = (n_images,) + img_size
     # Two draws in this order: source stack first, then target stack.
     U = rng.uniform(size=stack_shape)
     V = rng.uniform(size=stack_shape)
     X,Y,XYaxes = create_patches (
         raw_data            = RawData.from_arrays(U,V,img_axes),
         patch_size          = patch_size,
         patch_axes          = patch_axes,
         n_patches_per_image = n_patches_per_image,
         save_file           = save_file
     )
     (_X,_Y), val_data, _XYaxes = load_training_data(save_file,verbose=True)
     # Without a validation split no validation data is returned.
     assert val_data is None
     channel_pos = -1 if backend_channels_last else 1
     assert _XYaxes[channel_pos] == 'C'
     # Bring the reloaded arrays back to the axis order create_patches used.
     _X,_Y = (move_image_axes(u,fr=_XYaxes,to=XYaxes) for u in (_X,_Y))
     assert np.allclose(X,_X,atol=1e-6)
     assert np.allclose(Y,_Y,atol=1e-6)
     assert set(XYaxes) == set(_XYaxes)
     # The validation data is the second element of the returned triple (see
     # the unpacking above); index 2 is the axes string, which is never None,
     # so checking it made this assertion vacuous.
     assert load_training_data(save_file,validation_split=0.5)[1] is not None
     assert all(len(x)==3 for x in load_training_data(save_file,n_images=3)[0])
Exemplo n.º 6
0
def test_rawdata_from_folder(tmpdir):
    """``RawData.from_folder`` must yield the image pairs written to disk.

    Writes three random float32 images into each of the ``X`` and ``Y``
    sub-folders of ``tmpdir`` and checks the reported size, the axes, the
    absence of a mask, and that every yielded image matches one written image.
    """
    rng = np.random.RandomState(42)
    tmpdir = Path(str(tmpdir))

    n_images, img_size, img_axes = 3, (64,64), 'YX'
    stack_shape = (n_images,) + img_size
    # 'X' is drawn before 'Y' so the random stream is consumed in a fixed order.
    data = {
        folder: rng.uniform(size=stack_shape).astype(np.float32)
        for folder in ('X', 'Y')
    }

    for folder, stack in data.items():
        folder_path = tmpdir / folder
        folder_path.mkdir(exist_ok=True)
        for idx, img in enumerate(stack):
            imsave(str(folder_path / ('img_%02d.tif' % idx)), img)

    raw_data = RawData.from_folder(str(tmpdir), ['X'], 'Y', img_axes)
    assert raw_data.size == n_images
    for x, y, axes, mask in raw_data.generator():
        assert mask is None
        assert axes == img_axes
        # Only membership is checked, not the pairing or the order.
        assert any(np.allclose(x, written) for written in data['X'])
        assert any(np.allclose(y, written) for written in data['Y'])
Exemplo n.º 7
0
def generate_2D_patch_training_data(BaseDirectory,
                                    SaveNpzDirectory,
                                    SaveName,
                                    patch_size=(512, 512),
                                    n_patches_per_image=64,
                                    transforms=None):
    """Create 2D training patches from ``Original``/``BinaryMask`` image pairs.

    Images are read from the ``Original`` (source) and ``BinaryMask`` (target)
    sub-folders of ``BaseDirectory`` with axes ``'YX'``. The patches are saved
    to ``SaveNpzDirectory + SaveName``; the two strings are concatenated
    as-is, so ``SaveNpzDirectory`` must already end with a path separator.

    Nothing is returned; the patches are only written to disk.
    """
    output_file = SaveNpzDirectory + SaveName

    pairs = RawData.from_folder(
        basepath=BaseDirectory,
        source_dirs=['Original'],
        target_dir='BinaryMask',
        axes='YX',
    )

    create_patches(
        raw_data=pairs,
        patch_size=patch_size,
        n_patches_per_image=n_patches_per_image,
        transforms=transforms,
        save_file=output_file,
    )
Exemplo n.º 8
0
def data_generation(data_path, axes, patch_size, data_name):
    """Generate and save training patches for a CARE network.

    The image pairs are read from the ``noisy`` (source) and ``clean``
    (target) sub-folders of ``data_path``. 100 patches are extracted per
    image and saved to ``data/<data_name>.npz``. The shapes of the resulting
    arrays are checked to be equal and then printed together with the axes.

    Parameters
    ----------
    data_path : str
        Directory containing the ``noisy`` and ``clean`` folders.
    axes : str
        Axes string passed to ``RawData.from_folder``.
    patch_size : tuple
        Size of the patches to crop from the images. The original note asks
        for sizes that are a power of two along XYZT, or at least divisible
        by 8.
    data_name : str
        Base name (without extension) of the ``.npz`` file to write.

    Returns
    -------
    None
    """
    output_file = 'data/' + data_name + '.npz'

    pairs = RawData.from_folder(
        basepath=data_path,
        source_dirs=['noisy'],
        target_dir='clean',
        axes=axes,
    )

    # By convention X is the model input and Y the corresponding output.
    X, Y, XY_axes = create_patches(
        raw_data=pairs,
        patch_size=patch_size,
        n_patches_per_image=100,
        save_file=output_file,
    )

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)
Exemplo n.º 9
0
# os.listdir returns entries in arbitrary order. Sort both listings so the
# i-th image is paired with the i-th label (this assumes an image and its
# label share the same file name, or at least the same sort position).
imgList = sorted(os.listdir(train_dir))
labelList = sorted(os.listdir(label_dir))

imgArray = []
for image in tqdm(imgList, 'Reading img'):
    imgArray.append(imread(os.path.join(train_dir, image)))

labelArray = []
for label in tqdm(labelList, 'Reading label'):
    labelArray.append(imread(os.path.join(label_dir, label)))

print(imgArray[0].shape)
print(labelArray[0].shape)

raw_data = RawData.from_arrays(imgArray, labelArray, axes='YX')

# Single definition of the patch file so that writer and reader cannot drift.
patch_file = '/mnt/AE3205C73205958D/Data/3dliver_local/pc_adult/2d_slices/imagesXY/image_full/mydata_128x128patch.npz'

# The 2D images get a singleton channel axis in the extracted patches.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(128, 128, 1),
    patch_axes='YXC',
    n_patches_per_image=25,
    save_file=patch_file,
)

# Reload the saved patches, holding out 10% for validation.
(X, Y), (X_val, Y_val), axes = load_training_data(
    patch_file,
    validation_split=0.1,
    verbose=True)
Exemplo n.º 10
0
    def Train(self):
        """Prepare the mask folders, optionally build patches, and train the enabled models.

        Steps, in order:

        1. Make sure ``BinaryMask/`` and ``RealMask/`` exist under
           ``self.BaseDir``. If there are no ``RealMask`` TIFFs they are
           created by labelling the binary masks; if there are no
           ``BinaryMask`` TIFFs they are created from the instance masks
           (``image > 0``).
        2. If ``self.GenerateNPZ`` is set, extract ``Raw``/``BinaryMask``
           patches and save them to ``self.BaseDir + self.NPZfilename + '.npz'``.
        3. If ``self.TrainUNET`` is set, train a ``CARE`` model on that file.
        4. If ``self.TrainSTAR`` is set, train a ``StarDist3D`` model on the
           ``Raw``/``RealMask`` images, either loaded eagerly (with a 15%
           validation hold-out) or through ``self.DataSequencer`` when
           ``self.CroppedLoad`` is true.

        Existing weights are loaded before training in the order
        ``weights_now.h5``, ``weights_last.h5``, ``weights_best.h5``; each
        later file that exists overrides the previous one.

        Raises:
            ValueError: if ``self.backbone`` is neither ``'resnet'`` nor
                ``'unet'`` when ``self.TrainSTAR`` is set.
        """
        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        # Checked in this order; a later existing file overrides an earlier one.
        weight_files = ('weights_now.h5', 'weights_last.h5', 'weights_best.h5')

        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:
            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                Binaryimage = label(image)
                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

            # Pick up the files just written; otherwise the StarDist branch
            # below would pair the raw images with an empty mask list.
            RealMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')
            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                Binaryimage = image > 0
                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:
            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            # NOTE(review): BaseDir and NPZfilename are joined without a
            # separator, unlike the other paths in this method -- confirm
            # that NPZfilename is expected to start with '/'.
            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            unet_weights_dir = self.model_dir + 'UNET' + self.model_name + '/'

            # Warm-start from the copy model only when this model has no
            # 'weights_now.h5' of its own yet. Only that one file is
            # considered for the UNET copy model.
            if self.copy_model_dir is not None:
                copy_weights = (self.copy_model_dir + 'UNET' +
                                self.copy_model_name + '/' + 'weights_now.h5')
                if os.path.exists(copy_weights) and not os.path.exists(
                        unet_weights_dir + 'weights_now.h5'):
                    print('Loading copy model')
                    model.load_weights(copy_weights)

            for weights_name in weight_files:
                if os.path.exists(unet_weights_dir + weights_name):
                    print('Loading checkpoint model')
                    model.load_weights(unet_weights_dir + weights_name)

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if not self.CroppedLoad:
                # Eager path: read everything into memory and hold out 15%
                # (at least one image) for validation.
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val = [self.X[i] for i in ind_val]
                self.Y_val = [self.Y[i] for i in ind_val]
                self.X_trn = [self.X[i] for i in ind_train]
                self.Y_trn = [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))
            else:
                # Sequencer path: training and validation images come from
                # the Raw/RealMask and ValRaw/ValRealMask folders.
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
            kernel = (self.kern_size, self.kern_size, self.kern_size)

            # Settings shared by both backbones.
            # NOTE(review): train_patch_size is ordered (Z, X, Y) here while
            # the patch extraction above uses (Z, Y, X) -- confirm which
            # order is intended.
            shared_settings = dict(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                train_batch_size=self.batch_size,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1,
            )

            if self.backbone == 'resnet':
                conf = Config3D(resnet_n_blocks=self.depth,
                                resnet_kernel_size=kernel,
                                resnet_n_filter_base=self.startfilter,
                                **shared_settings)
            elif self.backbone == 'unet':
                conf = Config3D(unet_n_depth=self.depth,
                                unet_kernel_size=kernel,
                                unet_n_filter_base=self.startfilter,
                                train_sample_cache=False,
                                **shared_settings)
            else:
                # Previously this fell through and failed later with a
                # NameError on the unbound 'conf'.
                raise ValueError(
                    "Unsupported backbone %r; expected 'resnet' or 'unet'" %
                    (self.backbone, ))

            print(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)

            star_weights_dir = self.model_dir + self.model_name + '/'
            print(Starmodel._axes_tile_overlap('ZYX'),
                  os.path.exists(star_weights_dir + 'weights_now.h5'))

            # Warm-start from the copy model for every weight file this
            # model does not have yet.
            if self.copy_model_dir is not None:
                copy_weights_dir = (self.copy_model_dir +
                                    self.copy_model_name + '/')
                for weights_name in weight_files:
                    if os.path.exists(
                            copy_weights_dir +
                            weights_name) and not os.path.exists(
                                star_weights_dir + weights_name):
                        print('Loading copy model')
                        Starmodel.load_weights(copy_weights_dir +
                                               weights_name)

            for weights_name in weight_files:
                if os.path.exists(star_weights_dir + weights_name):
                    print('Loading checkpoint model')
                    Starmodel.load_weights(star_weights_dir + weights_name)

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
# print('image size         =', x.shape)
# print('Z subsample factor =', subsample)
# plt.figure(figsize=(389/100, 389/100))
# plt.imshow(x, cmap='gray')
# plt.show()
# print('image size         =', x.shape)
# print('Z subsample factor =', subsample)

# raw_data = RawData.from_folder (
#     basepath    = 'data',
#     source_dirs = ['simulator_data'],
#     target_dir  = 'simulator_data',
#     axes        = 'CYX',
# )

# The same image is used as both source and target; the transform passed to
# create_patches below is what makes the two differ.
raw_data = RawData.from_arrays(X=[x], Y=[x], axes='CYX')

anisotropic_transform = anisotropic_distortions(
    subsample=1,
    # Placeholder 3-tap kernel along Y -- use the actual PSF here.
    psf=np.ones(3) / 9,
    psf_axes='Y',
    # poisson_noise = True,
    # gauss_sigma = 0.1
)

# One patch per image, covering the first three dimensions of x in full.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=tuple(x.shape[axis] for axis in range(3)),
    n_patches_per_image=1,
    transforms=[anisotropic_transform],
)
Exemplo n.º 12
0
def _sample_validation_mask(image_shape, patch_val_size, scribble):
    """Sample a rectangular validation window that contains enough scribbles.

    Windows of size ``patch_val_size`` are drawn at random positions until the
    scribble channels summed inside the window exceed 10. Returns a float
    mask of shape ``image_shape[:2]`` that is 1 inside the window and 0
    elsewhere. Loops until such a window is found.
    """
    n_rows, n_cols = image_shape[0], image_shape[1]
    # Loop-invariant: total scribble signal per pixel.
    scribble_total = np.sum(scribble[...], axis=-1)
    while True:
        val_mask = np.zeros([n_rows, n_cols])
        ix_x = np.random.randint(0, n_rows - patch_val_size[0])
        # The second offset must be bounded by the second image dimension;
        # using the first one is only correct for square images.
        ix_y = np.random.randint(0, n_cols - patch_val_size[1])
        val_mask[ix_x:ix_x + patch_val_size[0], ix_y:ix_y + patch_val_size[1]] = 1
        if np.sum(val_mask * scribble_total) > 10:
            return val_mask


def _merge_border_into_background(Y, out_channels):
    """Collapse each consecutive channel triple of ``Y`` into two channels.

    For class ``j`` the first channel of its triple is kept as foreground and
    the remaining two are summed into a single background channel.
    """
    merged = np.zeros([Y.shape[0], Y.shape[1], Y.shape[2], out_channels * 2])
    for j in range(out_channels):
        # Each class owns three consecutive input channels, hence stride 3.
        merged[..., 2*j] = Y[..., 3*j]  # foreground
        merged[..., 2*j+1] = Y[..., 3*j+1] + Y[..., 3*j+2]  # Border + background are background
    return merged


def get_dataset(pd_scribbles,n_patches_per_image_train=30,n_patches_per_image_val=8,patch_size=(128, 128),
                p_label = 0.6,val_perc = 0.3,verbose = True, border = False):
    """Build training and validation patch sets from scribble-annotated images.

    For every row ``i`` of ``pd_scribbles`` the image is loaded from
    ``input_dir + input_file`` (key ``'image'``) and the scribbles from
    ``input_dir + scribble_file`` (key ``'scribble'``). A random rectangular
    validation window covering ``val_perc`` of each image side is sampled;
    validation patches are drawn inside it and training patches outside it
    via ``generate_patches_syxc``. Roughly ``1 - p_label`` of the requested
    patches are drawn without a label filter and the rest with a filter on
    all scribble channels.

    Parameters
    ----------
    pd_scribbles : table indexable as ``pd_scribbles[column][i]``
        Must provide the columns ``input_dir``, ``input_file`` and
        ``scribble_file`` (presumably a pandas DataFrame with a default
        integer index -- verify against the caller).
    n_patches_per_image_train, n_patches_per_image_val : int
        Number of patches requested per image for each group.
    patch_size : tuple
        Patch size passed to ``generate_patches_syxc``.
    p_label : float
        Fraction of patches that should come from the label-filtered sampling.
    val_perc : float
        Fraction of each image side used for the validation window.
    verbose : bool
        If true, plot the validation and training fields of view.
    border : bool
        If true, return the targets with their original channels; otherwise
        collapse every channel triple into (foreground, background).

    Returns
    -------
    tuple
        ``(X_train, Y_train, X_val, Y_val)``.

    Raises
    ------
    ValueError
        If ``pd_scribbles`` has no rows.
    """
    if len(pd_scribbles) == 0:
        raise ValueError("pd_scribbles is empty; no patches can be generated")

    requested = {'val': n_patches_per_image_val, 'train': n_patches_per_image_train}
    # Collect per-image patches and concatenate once at the end.
    collected = {'val': ([], []), 'train': ([], [])}

    for i in range(len(pd_scribbles)):

        ## read image
        with np.load(pd_scribbles['input_dir'][i] + pd_scribbles['input_file'][i]) as npz_read:
            image = npz_read['image']

        ## read scribbles
        with np.load(pd_scribbles['input_dir'][i] + pd_scribbles['scribble_file'][i]) as npz_read:
            scribble = npz_read['scribble']

        raw_image_in = image + 0  # work on a copy; normalization is not applied here

        ## Sample validation mask
        patch_val_size = [int(image.shape[0] * val_perc),
                          int(image.shape[1] * val_perc)]
        val_mask = _sample_validation_mask(raw_image_in.shape, patch_val_size, scribble)

        ## Generate patches
        raw_data = RawData.from_arrays(raw_image_in[np.newaxis, ...], scribble[np.newaxis, ...])

        ## for plot ##
        if verbose:
            # RGB overlay: red = first scribble channel, green = image,
            # blue = sum of the remaining scribble channels.
            aux = np.zeros([raw_image_in.shape[0], raw_image_in.shape[1], 3])
            if len(raw_image_in.shape)>2:
                aux[..., 1] = np.sum(raw_image_in,axis=-1) * 0.8
            else:
                aux[..., 1] = raw_image_in * 0.8
            aux[..., 0] = scribble[..., 0]
            aux[..., 2] = np.sum(scribble[..., 1:], axis=2)
        ###

        for group in ['val', 'train']:
            n_patches_per_image = requested[group]
            if group == 'val':
                fov_mask = np.array(val_mask)
                if verbose:
                    plt.figure(figsize=(10, 5))
                    plt.subplot(1, 2, 1)
                    plt.title('Validation FOV')
                    plt.imshow(fov_mask[..., np.newaxis] * aux)
            else:
                fov_mask = 1 - np.array(val_mask)
                if verbose:
                    plt.subplot(1, 2, 2)
                    plt.title('Train FOV')
                    plt.imshow(fov_mask[..., np.newaxis] * aux)
                    plt.show()

            # First the patches sampled without a label filter ...
            X_aux, Y_aux, axes = generate_patches_syxc(raw_data, patch_size,
                                                       int(n_patches_per_image * (1 - p_label)),
                                                       normalization=None, patch_filter=None,
                                                       fov_mask=fov_mask)

            # ... then top up with label-filtered patches to reach the quota.
            n_patches_add = int(n_patches_per_image - X_aux.shape[0])

            if n_patches_add > 0:
                X_labeled_aux, Y_labeled_aux, axes = generate_patches_syxc(raw_data, patch_size,
                                                                           n_patches_add,
                                                                           normalization=None,
                                                                           mask_filter_index=np.arange(
                                                                               scribble.shape[-1]),
                                                                           fov_mask=fov_mask)
                if X_labeled_aux is not None:
                    X_aux = np.concatenate([X_aux, X_labeled_aux], axis=0)
                    Y_aux = np.concatenate([Y_aux, Y_labeled_aux], axis=0)

            collected[group][0].append(X_aux)
            collected[group][1].append(Y_aux)

    X_val = np.concatenate(collected['val'][0], axis=0)
    Y_val = np.concatenate(collected['val'][1], axis=0)
    X_train = np.concatenate(collected['train'][0], axis=0)
    Y_train = np.concatenate(collected['train'][1], axis=0)

    print(Y_train.shape,Y_val.shape)
    if border:
        return X_train,Y_train,X_val,Y_val

    out_channels = int(Y_train.shape[-1]/3)
    Y_train_aux = _merge_border_into_background(Y_train, out_channels)
    Y_val_aux = _merge_border_into_background(Y_val, out_channels)
    return X_train,Y_train_aux,X_val,Y_val_aux
Exemplo n.º 13
0
from argparse import ArgumentParser
from csbdeep.data import RawData, create_patches

# Command line: four required positionals followed by optional patch settings.
args = ArgumentParser()
for positional in ("input_basepath", "input_x", "input_y", "output_file"):
    args.add_argument(positional)
args.add_argument("--patch_size_xy", default=32)
args.add_argument("--patch_size_z", default=16)
args.add_argument("--n_patches_per_image", default=750)
args.add_argument("--axes", default="ZYX")
args = args.parse_args()

data = RawData.from_folder(
    basepath=args.input_basepath,
    source_dirs=[args.input_x],
    target_dir=args.input_y,
    axes=args.axes,
)

# Values given on the command line arrive as strings, so convert explicitly.
ps_xy = int(args.patch_size_xy)
ps_z = int(args.patch_size_z)

# The patches get an extra singleton channel axis appended to the image axes.
X, Y, axes = create_patches(
    data,
    patch_size=(ps_z, ps_xy, ps_xy, 1),
    n_patches_per_image=int(args.n_patches_per_image),
    save_file=args.output_file,
    patch_axes=args.axes + "C",
)
Exemplo n.º 14
0
def original():
    """Generate anisotropy-distorted training patches from a time series.

    Loads ``data/img006_noconv.npy``, wraps every time point in a ``RawData``
    generator (the same array is yielded as source and target), applies
    ``data.anisotropic_distortions`` with a box-shaped PSF while extracting
    patches, drops the singleton Z axis, and writes the result plus two
    preview panels into the ``isonet_psf_1`` folder.

    NOTE(review): everything from the ``imread`` call inside the final
    ``for`` loop onwards reads ``args`` and ``img_number``, neither of which
    is defined in this function. That tail looks like a separate low/high-SNR
    patch-creation snippet that was fused onto this one -- confirm before
    running.
    """
    x = np.load('data/img006_noconv.npy')
    x = np.moveaxis(x, 1, 2)
    ## initial axes are TZCYX; the moveaxis above swaps Z and C (-> TCZYX),
    ## so every time point x[i] is laid out as CZYX
    axes = 'CZYX'

    # output folder for the .npz file and the preview figures
    mypath = Path('isonet_psf_1')
    mypath.mkdir(exist_ok=True)

    subsample = 5.0  #10.2
    print('image size       =', x.shape)
    print('image axes       =', axes)
    print('subsample factor =', subsample)

    # non-interactive backend: figures are only written to disk via savefig
    plt.switch_backend('agg')

    # disabled: preview plots of the first time point (xy and xz slices)
    if False:
        plt.figure(figsize=(15, 15))
        plot_some(np.moveaxis(x[0], 1, -1)[[5, -5]],
                  title_list=[['xy slice', 'xy slice']],
                  pmin=2,
                  pmax=99.8)
        plt.savefig(mypath / 'datagen_1.png')

        plt.figure(figsize=(15, 15))
        plot_some(np.moveaxis(np.moveaxis(x[0], 1, -1)[:, [50, -50]], 1, 0),
                  title_list=[['xz slice', 'xz slice']],
                  pmin=2,
                  pmax=99.8,
                  aspect=subsample)
        plt.savefig(mypath / 'datagen_2.png')

    def gimmeit_gen():
        """Yield (source, target, axes, mask) once per time point."""
        ## iterate over time dimension
        for i in range(x.shape[0]):
            # source and target are the same array; no mask is supplied
            yield x[i], x[i], axes, None

    raw_data = RawData(gimmeit_gen, x.shape[0], "this is great!")

    ## initial idea (disabled): Gaussian-shaped kernel
    if False:

        def buildkernel():
            kern = np.exp(-(np.arange(10)**2 / 2))
            kern /= kern.sum()
            kern = kern.reshape([1, 1, -1, 1])
            kern = np.stack([kern, kern], axis=1)
            return kern

        psf_kern = buildkernel()

    ## use Martins theoretical psf (disabled)
    if False:
        psf_aniso = imread('data/psf_aniso_NA_0.8.tif')
        psf_channels = np.stack([
            psf_aniso,
        ] * 2, axis=1)

    def buildkernel():
        """Return a normalised box kernel of shape (2, 1, 1, 20)."""
        kernel = np.zeros(20)
        # six taps of 1/6 each, so the kernel sums to one
        kernel[7:13] = 1 / 6
        ## reshape into CZYX. long axis is X.
        kernel = kernel.reshape([1, 1, -1])
        ## repeat same kernel for both channels
        kernel = np.stack([kernel, kernel], axis=0)
        return kernel

    psf = buildkernel()
    print(psf.shape)

    ## use measured psfs loaded from disk (disabled)
    if False:
        psf_channels = np.load('data/measured_psfs.npy')
        psf = rotate(psf_channels, 90, axes=(1, 3))

    # NOTE(review): `data` is presumably the csbdeep.data module -- confirm
    # against the file's imports.
    iso_transform = data.anisotropic_distortions(
        subsample=subsample,
        psf=psf,
    )

    X, Y, XY_axes = data.create_patches(
        raw_data=raw_data,
        patch_size=(2, 1, 128, 128),
        n_patches_per_image=256,
        transforms=[iso_transform],
    )

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)

    # remove dummy z dim to obtain multi-channel 2D patches
    X = X[:, :, 0, ...]
    Y = Y[:, :, 0, ...]
    XY_axes = XY_axes.replace('Z', '')

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)

    np.savez(mypath / 'my_training_data.npz', X=X, Y=Y, axes=XY_axes)

    # two preview panels of eight source/target patch pairs each
    for i in range(2):
        plt.figure(figsize=(16, 4))
        sl = slice(8 * i, 8 * (i + 1))
        plot_some(np.moveaxis(X[sl], 1, -1),
                  np.moveaxis(Y[sl], 1, -1),
                  title_list=[np.arange(sl.start, sl.stop)])
        plt.savefig(mypath / 'datagen_panel_{}.png'.format(i))
        # NOTE(review): from here on the code uses `args` and `img_number`,
        # which are not defined in this function, and `x` is the array loaded
        # at the top rather than a low-SNR image. Looks like the body of a
        # different snippet -- confirm.
        y = imread(
            os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"))
        print(f"image size: {x.shape}")
        plt.figure(figsize=(16, 10))
        plot_some(
            np.stack([x, y]),
            title_list=[['low snr', 'high snr']],
        )
        plt.show()

    # move the selected low/high SNR pair into <base_dir>/test so that it is
    # no longer inside the folders read by RawData.from_folder below
    if not os.path.exists(os.path.join(args.base_dir, "test")):
        os.mkdir(os.path.join(args.base_dir, "test"))
    shutil.move(
        os.path.join(args.base_dir, "low_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "low_snr.tif"))
    shutil.move(
        os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "high_snr.tif"))

    # Read the pairs, passing in the axis semantics
    raw_data = RawData.from_folder(basepath=args.base_dir,
                                   source_dirs=['low_snr'],
                                   target_dir='high_snr',
                                   axes='YX')

    # From the stacks, generate 2D patches, and save to the output_filename
    create_patches(raw_data=raw_data,
                   patch_size=(128, 128),
                   n_patches_per_image=512,
                   save_file=os.path.join(args.base_dir, args.output_filename))
Exemplo n.º 16
0
    def training(self):
        """Train a CSBDeep CARE model on paired 2D TIFF images and export it.

        All paths are relative to the current working directory:

        * ``data/training/{original,ground_truth}/*.tif`` and
          ``data/validation/{original,ground_truth}/*.tif`` are the inputs;
        * ``data/training.npz`` and ``data/validation.npz`` are written by
          ``create_patches`` and read back with ``load_training_data``;
        * ``models/CSBDeep/model.h5`` holds the best checkpoint.

        Workflow: preview the first image of each folder, create one patch
        per image, wrap the arrays in augmenting Keras generators, then run
        ``epochs`` rounds of one Keras epoch each. Every round rebuilds the
        model with a decayed learning rate, restarts from the best checkpoint
        saved so far, and saves a new checkpoint when the validation MAE
        improves. Finally the best weights are reloaded and the model is
        exported with ``export_TF``.

        ``self`` is not used.

        Raises:
            FileNotFoundError: if no round wrote a checkpoint, i.e. the
                validation MAE never dropped below the initial threshold.
        """
        import glob
        import os

        from csbdeep.data import RawData, create_patches, no_background_patches
        from csbdeep.io import load_training_data
        from csbdeep.models import Config, CARE
        from csbdeep.utils import axes_dict
        from keras import backend as K
        from keras.preprocessing.image import ImageDataGenerator
        from matplotlib import pyplot as plt
        from skimage import io

        # ------------------------------------------------------------------
        # Load the file lists and preview the first image of every folder
        # ------------------------------------------------------------------
        basepath = 'data/'
        training_original_dir = 'training/original/'
        training_ground_truth_dir = 'training/ground_truth/'
        validation_original_dir = 'validation/original/'
        validation_ground_truth_dir = 'validation/ground_truth/'

        training_original_files = sorted(
            glob.glob(basepath + training_original_dir + '*.tif'))
        training_original_file = io.imread(training_original_files[0])
        training_ground_truth_files = sorted(
            glob.glob(basepath + training_ground_truth_dir + '*.tif'))
        training_ground_truth_file = io.imread(training_ground_truth_files[0])
        print("Training dataset's number of files and dimensions: ",
              len(training_original_files), training_original_file.shape,
              len(training_ground_truth_files),
              training_ground_truth_file.shape)

        validation_original_files = sorted(
            glob.glob(basepath + validation_original_dir + '*.tif'))
        validation_original_file = io.imread(validation_original_files[0])
        validation_ground_truth_files = sorted(
            glob.glob(basepath + validation_ground_truth_dir + '*.tif'))
        validation_ground_truth_file = io.imread(
            validation_ground_truth_files[0])
        print("Validation dataset's number of files and dimensions: ",
              len(validation_original_files), validation_original_file.shape,
              len(validation_ground_truth_files),
              validation_ground_truth_file.shape)

        # Compare the complete shapes: len() alone would only look at the
        # first axis and miss a mismatch in the remaining ones.
        if training_original_file.shape != validation_original_file.shape:
            print(
                'Training and validation images should be of the same dimensions!'
            )

        preview_images = (training_original_file, training_ground_truth_file,
                          validation_original_file,
                          validation_ground_truth_file)
        plt.figure(figsize=(16, 4))
        for position, image in enumerate(preview_images, start=141):
            plt.subplot(position)
            plt.imshow(image)

        # ------------------------------------------------------------------
        # Prepare the network inputs from pairs of 32-bit TIFF files with
        # intensities in range 0..1. Only needed once per new dataset.
        # ------------------------------------------------------------------
        training_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[training_original_dir],
            target_dir=training_ground_truth_dir,
            axes='YX',
        )
        validation_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[validation_original_dir],
            target_dir=validation_ground_truth_dir,
            axes='YX',
        )

        # Patches are effectively produced later by the augmentation step, so
        # a single patch is taken per image here.
        # NOTE(review): the original comment says the patch size equals the
        # image dimensions, which presumes 64x64 images -- confirm.
        size1 = 64
        # The returned arrays are not kept; the saved .npz files are read
        # back below through load_training_data.
        create_patches(
            raw_data=training_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'training.npz',
        )
        create_patches(
            raw_data=validation_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'validation.npz',
        )

        # Load training and validation data into memory
        (X, Y), _, axes = load_training_data(basepath + 'training.npz',
                                             verbose=False)
        (X_val,
         Y_val), _, axes = load_training_data(basepath + 'validation.npz',
                                              verbose=False)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        # ------------------------------------------------------------------
        # Data augmentation
        # ------------------------------------------------------------------
        # One batch holds the complete training set; lower this if memory is
        # tight.
        batch = len(X)
        seed = 1
        # width/height shifts are intentionally not enabled
        data_gen_args = dict(samplewise_center=False,
                             samplewise_std_normalization=False,
                             zca_whitening=False,
                             fill_mode='reflect',
                             rotation_range=30,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             vertical_flip=True)

        # Image and mask generators share the same seed so that both streams
        # receive the same sequence of random transformations.
        # training
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)
        image_datagen.fit(X, augment=True, seed=seed)
        mask_datagen.fit(Y, augment=True, seed=seed)
        image_generator = image_datagen.flow(X, batch_size=batch, seed=seed)
        mask_generator = mask_datagen.flow(Y, batch_size=batch, seed=seed)
        generator = zip(image_generator, mask_generator)

        # validation
        image_datagen_val = ImageDataGenerator(**data_gen_args)
        mask_datagen_val = ImageDataGenerator(**data_gen_args)
        image_datagen_val.fit(X_val, augment=True, seed=seed)
        mask_datagen_val.fit(Y_val, augment=True, seed=seed)
        image_generator_val = image_datagen_val.flow(X_val,
                                                     batch_size=batch,
                                                     seed=seed)
        mask_generator_val = mask_datagen_val.flow(Y_val,
                                                   batch_size=batch,
                                                   seed=seed)
        generator_val = zip(image_generator_val, mask_generator_val)

        # plot one augmented example from each stream
        x, y = next(generator)
        x_val, y_val = next(generator_val)

        augmented_preview = (x, y, x_val, y_val)
        plt.figure(figsize=(16, 4))
        for position, sample in enumerate(augmented_preview, start=141):
            plt.subplot(position)
            plt.imshow(sample[0, :, :, 0])

        # ------------------------------------------------------------------
        # Training
        # ------------------------------------------------------------------
        blocks = 2
        channels = 16
        learning_rate = 0.0004
        learning_rate_decay_factor = 0.95
        epoch_size_multiplicator = 20  # basically, how often do we decrease learning rate
        epochs = 10
        kernel_size = 3
        model_path = 'models/CSBDeep/model.h5'

        # Start from a clean checkpoint and make sure its directory exists;
        # checking 'models/' alone does not cover the 'CSBDeep' sub-folder.
        if os.path.isfile(model_path):
            print('Your model will be overwritten in the next cell')
            os.remove(model_path)
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        def build_model(rate):
            """Create a CARE model with the shared U-Net settings."""
            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_kern_size=kernel_size,
                            train_learning_rate=rate,
                            unet_n_depth=blocks,
                            unet_n_first=channels)
            return CARE(config, '.', basedir='models')

        # Only rounds with a validation MAE below this value are saved.
        best_mae = 1
        steps_per_epoch = len(X) * epoch_size_multiplicator
        validation_steps = len(X_val) * epoch_size_multiplicator

        # No `del model` here: `model` is a local name of this method, so
        # deleting it before its first assignment raises UnboundLocalError.
        for i in range(epochs):
            print('Epoch:', i + 1)
            learning_rate = learning_rate * learning_rate_decay_factor
            model = build_model(learning_rate)
            # Continue from the best weights so far; the file is missing as
            # long as no round has beaten the initial threshold.
            if i > 0 and os.path.isfile(model_path):
                model.keras_model.load_weights(model_path)
            model.prepare_for_training()
            history = model.keras_model.fit_generator(
                generator,
                validation_data=generator_val,
                validation_steps=validation_steps,
                epochs=1,
                verbose=0,
                shuffle=True,
                steps_per_epoch=steps_per_epoch)
            val_mae = history.history['val_mae'][0]
            if val_mae < best_mae:
                best_mae = val_mae
                model.keras_model.save(model_path)
                print(f'Validation MAE:{best_mae:.3E}')
            # free the Keras graph before the next round rebuilds the model
            del model
            K.clear_session()

        # ------------------------------------------------------------------
        # Reload the best checkpoint and export it
        # ------------------------------------------------------------------
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f'No checkpoint was written to {model_path}: the validation '
                f'MAE never dropped below {best_mae}.')
        model = build_model(learning_rate)
        model.keras_model.load_weights(model_path)
        model.export_TF()