Exemplo n.º 1
0
# Example 1: restore every raw stack with CARE, project it to 2D and save the
# projection. Uses names defined outside this cell: basedir,
# basedirResults2Dextended, RestorationModel, ProjectionModel, imread and
# save_tiff_imagej_compatible.
Raw_path = os.path.join(basedir, '*TIF')  # glob is case-sensitive on most systems: '*tif' vs '*TIF'

axes = 'ZYX'  # axes of the raw stacks; the projection drops ProjectionModel.proj_params.axis

filesRaw = glob.glob(Raw_path)


# In[6]:


for fname in filesRaw:
    out_fname = basedirResults2Dextended + '_' + os.path.basename(fname)
    # Skip stacks whose projection already exists, so the cell can be re-run
    # after an interruption without redoing finished work.
    if os.path.exists(out_fname):
        continue
    print(fname)
    y = imread(fname)
    # n_tiles splits the (z, y, x) volume into tiles; the overlap between tiles
    # is handled by CSBDeep. (1, 2, 2) works for light images; fewer tiles
    # means faster prediction.
    restored = RestorationModel.predict(y, axes, n_tiles=(1, 4, 4))
    projection = ProjectionModel.predict(restored, axes, n_tiles=(1, 1, 2))
    axes_restored = axes.replace(ProjectionModel.proj_params.axis, '')
    # Predictions are floats: clip to the uint8 range before casting so that
    # out-of-range values saturate instead of wrapping around.
    projection = projection.clip(0, 255).astype('uint8')
    # To keep the restored 3D stack as well, save it here. Use uint8 when
    # prediction and projection run together; use uint16 (clip to 0..65535)
    # when building a projection training set or projecting later:
    # restored = restored.clip(0, 255).astype('uint8')
    # save_tiff_imagej_compatible(basedirResults3Dextended + os.path.basename(fname), restored, axes)
    save_tiff_imagej_compatible(out_fname, projection, axes_restored)


# Normal images: restored = RestorationModel.predict(y, axes, n_tiles=(1, 2, 4))
#                projection = ProjectionModel.predict(restored, axes, n_tiles=(1, 1, 1))
                
# In[]:

Exemplo n.º 2
0
    def predict(self, file_fn, n_tiles=(1, 4, 4), keep_meta=True,
                save_timepoints=False):
        """Apply the per-channel CARE models to every time point of an image file.

        For each channel in ``self.train_channels``, the CARE model named
        ``CH_<ch>_model`` is loaded from ``<self.out_dir>/models``. Each time
        point is processed as follows:

        - read plane by plane;
        - rescaled by ``self.low_scaling`` (z, y, x factors, interpolation
          order ``self.order``);
        - predicted, then clipped to the input pixel type.

        The channel's stack is written next to ``file_fn`` as
        ``<stem>_care_predict_ch<ch>.tif`` with axes TZCYX.

        Args:
            file_fn: path of the input image, read with Bio-Formats.
            n_tiles: tiling of the (z, y, x) volume passed to ``CARE.predict``.
            keep_meta: if True, store the rescaled pixel size, z spacing, unit
                and frame interval as ImageJ metadata in the output file.
            save_timepoints: if True, also write each (time point, channel)
                prediction to ``<stem>_care_predict_tp<t>_ch<ch>.tif``.

        Returns:
            None. Prints an error and returns early when the pixel type is
            neither uint8 nor uint16.
        """
        JVM().start()

        pixel_reso = get_space_time_resolution(file_fn)
        print("Prediction {}".format(file_fn))
        print(" -- Using pixel sizes and frame interval", pixel_reso)

        # The context manager closes the Bio-Formats reader on every exit
        # path, including the early return for unsupported pixel types.
        with bf.ImageReader(file_fn) as ir:
            reader = ir.rdr

            # Pixel type codes handled here: 1 = uint8, 3 = uint16.
            loci_pixel_type = reader.getPixelType()
            if loci_pixel_type == 1:
                dtype = numpy.uint8
            elif loci_pixel_type == 3:
                dtype = numpy.uint16
            else:
                print(
                    "Error: Pixel-type not supported. Pixel type must be 8- or 16-bit"
                )
                return

            series = 0
            z_size = reader.getSizeZ()
            y_size = reader.getSizeY()
            x_size = reader.getSizeX()
            c_size = reader.getSizeC()
            t_size = reader.getSizeT()

            # Size the output the way skimage.transform.rescale sizes its
            # result: max(round(scale * shape), 1). Truncating with int() can
            # come out one pixel short (e.g. 7 * 0.7 == 4.8999...), which then
            # breaks the assignment of the prediction into res_image_ch.
            out_shape = numpy.maximum(
                numpy.round(numpy.asarray(self.low_scaling)
                            * numpy.array([z_size, y_size, x_size])),
                1).astype(int)
            z_out_size, y_out_size, x_out_size = (int(n) for n in out_shape)

            if c_size != len(self.train_channels):
                print(
                    " -- Warning: Number of Channels during training and prediction do not match. Using channels {} for prediction"
                    .format(self.train_channels))

            # Loop invariants: clipping range of the output type and the
            # pieces of the output file names.
            di = numpy.iinfo(dtype)
            out_dir = os.path.dirname(file_fn)
            stem = os.path.splitext(os.path.basename(file_fn))[0]

            for ch in self.train_channels:
                model = CARE(None,
                             'CH_{}_model'.format(ch),
                             basedir=pathlib.Path(self.out_dir) / 'models')
                res_image_ch = numpy.zeros(shape=(t_size, z_out_size, 1,
                                                  y_out_size, x_out_size),
                                           dtype=dtype)
                print(" -- Predicting channel {}".format(ch))
                for t in tqdm(range(t_size)):
                    img_3d = numpy.zeros((z_size, y_size, x_size), dtype=dtype)
                    for z in range(z_size):
                        img_3d[z, :, :] = ir.read(series=series,
                                                  z=z,
                                                  c=ch,
                                                  t=t,
                                                  rescale=False)

                    # NOTE: `multichannel` is deprecated since scikit-image
                    # 0.19 in favour of `channel_axis=None`.
                    img_3d_ch_ex = rescale(img_3d,
                                           self.low_scaling,
                                           preserve_range=True,
                                           order=self.order,
                                           multichannel=False,
                                           mode="reflect",
                                           anti_aliasing=True)

                    pred = model.predict(img_3d_ch_ex, axes='ZYX',
                                         n_tiles=n_tiles)
                    # Saturate at the limits of the input type instead of
                    # wrapping around on the cast.
                    pred = pred.clip(di.min, di.max).astype(dtype)

                    res_image_ch[t, :, 0, :, :] = pred

                    if save_timepoints:
                        ch_t_out_fn = os.path.join(
                            out_dir,
                            stem + "_care_predict_tp{:04d}_ch{}.tif".format(t, ch))
                        print("Saving time-point {} and channel {} to file '{}'".
                              format(t, ch, ch_t_out_fn))
                        # Add singleton T and C axes so that ZYX matches TZCYX.
                        tifffile.imsave(ch_t_out_fn,
                                        pred[None, :, None, :, :],
                                        imagej=True,
                                        metadata={'axes': 'TZCYX'})

                ch_out_fn = os.path.join(
                    out_dir, stem + "_care_predict_ch{}.tif".format(ch))
                print(" -- Saving channel {} CARE prediction to file '{}'".format(
                    ch, ch_out_fn))

                if keep_meta:
                    # Rescaling by low_scaling changes the physical pixel size
                    # to size / scale; ImageJ resolution is pixels per unit.
                    reso = (1 / (pixel_reso.X / self.low_scaling[2]),
                            1 / (pixel_reso.Y / self.low_scaling[1]))
                    spacing = pixel_reso.Z / self.low_scaling[0]
                    unit = pixel_reso.Xunit
                    finterval = pixel_reso.T

                    tifffile.imsave(ch_out_fn,
                                    res_image_ch,
                                    imagej=True,
                                    resolution=reso,
                                    metadata={
                                        'axes': 'TZCYX',
                                        'finterval': finterval,
                                        'spacing': spacing,
                                        'unit': unit
                                    })
                else:
                    tifffile.imsave(ch_out_fn, res_image_ch)

                res_image_ch = None  # drop the reference so the buffer can be freed
Exemplo n.º 3
0
# Example 3: restore every raw stack with CARE, project it to 2D and save the
# projection (and optionally the restored 3D stack). Uses names defined
# outside this cell: basedir, basedirResults2Dextended,
# basedirResults3Dextended, RestorationModel, ProjectionModel, imread and
# save_tiff_imagej_compatible.
Raw_path = os.path.join(basedir, '*TIF')  # glob is case-sensitive on most systems: '*tif' vs '*TIF'

axes = 'ZYX'  # axes of the raw stacks; the projection drops ProjectionModel.proj_params.axis

# Whether the restored 3D stack is also written to basedirResults3Dextended.
# The "already processed" check below only looks for outputs that are actually
# written. Requiring a 3D file that is never created would force every stack
# to be re-processed on every run.
save_restored = False

filesRaw = glob.glob(Raw_path)


# In[6]:

for fname in filesRaw:
    out_3d = basedirResults3Dextended + os.path.basename(fname)
    out_2d = basedirResults2Dextended + '_' + os.path.basename(fname)
    already_done = os.path.exists(out_2d) and (not save_restored or os.path.exists(out_3d))
    if already_done:
        continue
    print(fname)
    y = imread(fname)
    # n_tiles splits the (z, y, x) volume into tiles; the overlap between tiles
    # is handled by CSBDeep. Fewer tiles means faster prediction. Alternative
    # tilings used before: restoration (1, 4, 8), projection (1, 1, 2).
    restored = RestorationModel.predict(y, axes, n_tiles=(1, 8, 8))
    projection = ProjectionModel.predict(restored, axes, n_tiles=(1, 4, 4))
    axes_restored = axes.replace(ProjectionModel.proj_params.axis, '')
    # Predictions are floats: clip to the uint8 range before casting so that
    # out-of-range values saturate instead of wrapping around.
    projection = projection.clip(0, 255).astype('uint8')
    if save_restored:
        # uint8 when prediction and projection run together; use uint16
        # (clip to 0..65535) when building a projection training set instead.
        restored = restored.clip(0, 255).astype('uint8')
        save_tiff_imagej_compatible(out_3d, restored, axes)
    save_tiff_imagej_compatible(out_2d, projection, axes_restored)


# In[]:


from csbdeep.utils import Path
Exemplo n.º 4
0
def n2v_flim(project, n2v_num_pix=32):
    """Denoise the fit results of a FLIMfit project with Noise2Void.

    Reads ``<project>/fit_results.hdf5`` and trains a model on it, saved as
    ``n2v_model`` under ``project``. It then writes a copy of the project
    next to the input, named ``<name>-n2v<ext>`` (e.g. ``data.flimfit`` ->
    ``data-n2v.flimfit``). In that copy, ``fit_results.hdf5`` holds the
    denoised results. An existing output project of that name is deleted
    first.

    Args:
        project: path of the FLIMfit project directory.
        n2v_num_pix: forwarded as the config's ``n2v_num_pix``. It is also
            the side length of the square training patches
            (``n2v_patch_shape``).
    """
    results_file = os.path.join(project, 'fit_results.hdf5')

    # Derive the output name from the last path component only. The old
    # project.replace('.flimfit', '-n2v.flimfit') returned the input path
    # unchanged when it did not contain '.flimfit', so rmtree() below deleted
    # the input project. Inserting '-n2v' before the extension never collides
    # with the input.
    root, ext = os.path.splitext(os.path.normpath(project))
    output_project = root + '-n2v' + ext

    X, groups, mask = extract_results(results_file)
    print(np.shape(X))

    # Standardize with the global mean/std; undone by denormalize() after prediction.
    mean, std = np.mean(X), np.std(X)
    X = normalize(X, mean, std)

    XA = X  # training inputs (augment_data(X) could be used here instead)

    # manipulate_val_data() works in place (its return value is not used), and
    # in n2v it also overwrites the masked pixels of X_val. Copy rather than
    # slice: a slice is a view into X, so the manipulation would corrupt the
    # training inputs and the images that are denoised at the end.
    X_val = X[0:10, ...].copy()

    # We concatenate an extra channel filled with zeros. It will be internally used for the masking.
    Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
    Y_val = np.concatenate((X_val, np.zeros(X_val.shape)), axis=-1)

    # NOTE(review): assumes square images (shape[1] == shape[2]) with the
    # channels last; confirm against extract_results().
    n_x = X.shape[1]
    n_chan = X.shape[-1]

    manipulate_val_data(X_val, Y_val, num_pix=n_x * n_x * 2 / n2v_num_pix,
                        shape=(n_x, n_x))

    # You can increase "train_steps_per_epoch" to get even better results at the price of longer computation.
    # NOTE(review): n2v_neighborhood_radius is passed as the string '5';
    # n2v's own config validation expects an int -- confirm with the
    # installed version before changing it.
    config = Config('SYXC',
                    n_channel_in=n_chan,
                    n_channel_out=n_chan,
                    unet_kern_size=5,
                    unet_n_depth=2,
                    train_steps_per_epoch=200,
                    train_loss='mae',
                    train_epochs=35,
                    batch_norm=False,
                    train_scheme='Noise2Void',
                    train_batch_size=128,
                    n2v_num_pix=n2v_num_pix,
                    n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
                    n2v_manipulator='uniform_withCP',
                    n2v_neighborhood_radius='5')

    model = CARE(config, 'n2v_model', basedir=project)
    model.train(XA, Y, validation_data=(X_val, Y_val))

    # Predict with the weights saved as weights_best.h5 rather than the final ones.
    model.load_weights(name='weights_best.h5')

    if os.path.exists(output_project):
        shutil.rmtree(output_project)
    shutil.copytree(project, output_project)
    output_file = os.path.join(output_project, 'fit_results.hdf5')

    X_pred = np.zeros(X.shape)
    for i, image in enumerate(X):
        X_pred[i, ...] = denormalize(
            model.predict(image, axes='YXC', normalizer=None), mean, std)

    # Entries selected by `mask` are stored as NaN. np.nan rather than
    # np.NaN: the capitalized alias was removed in NumPy 2.0.
    X_pred[mask] = np.nan

    insert_results(output_file, X_pred, groups)