Example #1
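This `Train` method comes from a 3D segmentation training wrapper built on CSBDeep and StarDist. It derives instance masks from binary masks (or vice versa) when either is missing on disk, optionally packages raw/mask patch pairs into an NPZ file, and then trains a U-Net (via CARE) and/or a StarDist3D model, resuming from any checkpoints it finds.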
    # Expects module-level imports: glob, os, numpy as np, pathlib.Path,
    # tifffile (imread, imwrite), skimage.measure.label, tqdm,
    # matplotlib.pyplot as plt, csbdeep (RawData, create_patches,
    # load_training_data, axes_dict, Config, CARE, plot_history, normalize)
    # and stardist (Rays_GoldenSpiral, Config3D, StarDist3D), plus the
    # project-local helpers ReadFloat, ReadInt, DownsampleData, DataSequencer.
    def Train(self):

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:

            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                # Connected-component labelling turns each binary mask
                # into an instance (integer-labelled) mask.
                Labelimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Labelimage.astype('uint16'))

            # Pick up the instance masks that were just written.
            RealMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')

            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                # Threshold the instance labels back to a binary mask.
                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

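            # Pair each image in Raw/ with its BinaryMask/ counterpart and
            # cut 3D patches from the pairs, saved in a single NPZ file.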
            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val, Y_val), axes = load_training_data(
                load_path, validation_split=0.1, verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            if self.copy_model_dir is not None:
                # Seed from the copy model only if no own checkpoint exists yet.
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name +
                                  '/weights_now.h5') and not os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/weights_now.h5'):
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name +
                                       '/weights_now.h5')

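            # If several checkpoints exist, the last one loaded below
            # (weights_best.h5) takes precedence.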
            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDist3D model with', self.backbone, 'backbone')
            self.axis_norm = (0, 1, 2)
            if not self.CroppedLoad:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val = [self.X[i] for i in ind_val]
                self.Y_val = [self.Y[i] for i in ind_val]
                self.X_trn = [self.X[i] for i in ind_train]
                self.Y_trn = [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    # ZYX order, matching the data axes
                    train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    # ZYX order, matching the data axes
                    train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            if self.copy_model_dir is not None:
                # Seed from the copy model only where no own checkpoint exists yet.
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/weights_now.h5') and not os.path.exists(
                                      self.model_dir + self.model_name +
                                      '/weights_now.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name +
                                           '/weights_now.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/weights_last.h5') and not os.path.exists(
                                      self.model_dir + self.model_name +
                                      '/weights_last.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name +
                                           '/weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/weights_best.h5') and not os.path.exists(
                                      self.model_dir + self.model_name +
                                      '/weights_best.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name +
                                           '/weights_best.h5')

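            # As above, the last existing checkpoint loaded wins
            # (best overrides last, which overrides now).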
            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
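
The repeated `os.path.exists(...)`/`load_weights(...)` blocks above all follow one pattern; a small helper along these lines could capture it (`load_latest_weights` is hypothetical, not part of the original class):

import os


def load_latest_weights(model, model_dir):
    # Mirrors the code above: later suffixes override earlier ones, so
    # weights_best.h5 wins over weights_last.h5, which wins over weights_now.h5.
    for suffix in ('now', 'last', 'best'):
        weights = os.path.join(model_dir, 'weights_%s.h5' % suffix)
        if os.path.exists(weights):
            print('Loading checkpoint model')
            model.load_weights(weights)

# e.g. (inside Train): load_latest_weights(Starmodel, self.model_dir + self.model_name)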
Example #2
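This `main` function is a command-line front end for applying a trained CARE model to a folder of TIFF files: it validates arguments, loads the model and a percentile normalizer, restores each image (optionally tiled), writes the results to an output directory, and prints a summary at the end.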
# Expects module-level imports: sys, numpy as np, pathlib.Path, pprint,
# tqdm, plus csbdeep utilities (axes_check_and_normalize, _raise,
# save_tiff_imagej_compatible); parse_args is defined elsewhere in this
# script, and the heavier imports are deliberately deferred below.
def main():
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (argparse's own `required` would
    # error out before the no-argument help above could be shown)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" %
                  (parser.prog, r),
                  file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if len(file_list) == 0:
        log("No files to process in '%s' with pattern '%s'." %
            (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay heavy imports until after all required arguments have been checked
    from tifffile import imread, imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        print("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin,
                                      pmax=args.norm_pmax,
                                      do_after=args.norm_undo)

    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files
    for file_in in tqdm(file_list,
                        disable=args.quiet
                        or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=Path(args.model_weights).stem
            if args.model_weights is not None else None)

        # ensure both input and output files are TIFFs
        (file_in.suffix.lower() in ('.tif', '.tiff')
         and file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(
             ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img,
                                 axes=args.input_axes,
                                 normalizer=normalizer,
                                 n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            assert restored.ndim == img.ndim + 1
            assert 'C' not in axes_out
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s' %
            (n_processed, 'files' if n_processed > 1 else 'file'))
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(
                1 + i, str(file_in)))
            log(('{:>%d}  out: {:>%d}' % (len_processed, len_file)).format(
                '', str(file_out)))
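
For reference, `args.output_name` above is a `str.format` template; a standalone sketch of the same substitution (all values below are illustrative, not defaults of the script):

from pathlib import Path

# Hypothetical inputs for one file.
input_dir = 'images'
output_dir = 'results'
output_name = '{file_path}/{file_name}_restored{file_ext}'  # illustrative template
file_in = Path('images/plate1/stack01.tif')

file_out = Path(output_dir) / output_name.format(
    file_path=str(file_in.relative_to(input_dir).parent),
    file_name=file_in.stem,
    file_ext=file_in.suffix,
    model_name='my_model',   # available to the template, unused here
    model_weights=None)
print(file_out)  # results/plate1/stack01_restored.tif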
Example #3
File: n2v_flim.py  Project: seanwarren/n2v
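This script denoises FLIM fit results with the Noise2Void training scheme on a CARE model: the data itself serves as the training target via an extra all-zero masking channel, and the denoised predictions are written into a copy of the original .flimfit project.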
import os
import shutil

import numpy as np

# Config and CARE come from the project's csbdeep fork; extract_results,
# insert_results, normalize, denormalize and manipulate_val_data are helpers
# from the surrounding n2v codebase.


def n2v_flim(project, n2v_num_pix=32):

   results_file = os.path.join(project, 'fit_results.hdf5')

   X, groups, mask = extract_results(results_file)
   data_shape = np.shape(X)
   print(data_shape)

   mean, std = np.mean(X), np.std(X)
   X = normalize(X, mean, std)

   XA = X  # augment_data(X) disabled

   X_val = X[0:10,...]

   # We concatenate an extra channel filled with zeros. It will be internally used for the masking.
   Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
   Y_val = np.concatenate((X_val.copy(), np.zeros(X_val.shape)), axis=-1) 

   n_x = X.shape[1]
   n_chan = X.shape[-1]

   manipulate_val_data(X_val, Y_val,
                       num_pix=n_x * n_x * 2 // n2v_num_pix,  # integer pixel count
                       shape=(n_x, n_x))


   # You can increase "train_steps_per_epoch" to get even better results,
   # at the price of longer computation.
   config = Config('SYXC',
                   n_channel_in=n_chan,
                   n_channel_out=n_chan,
                   unet_kern_size=5,
                   unet_n_depth=2,
                   train_steps_per_epoch=200,
                   train_loss='mae',
                   train_epochs=35,
                   batch_norm=False,
                   train_scheme='Noise2Void',
                   train_batch_size=128,
                   n2v_num_pix=n2v_num_pix,
                   n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
                   n2v_manipulator='uniform_withCP',
                   n2v_neighborhood_radius=5)


   model = CARE(config, 'n2v_model', basedir=project)

   history = model.train(XA, Y, validation_data=(X_val,Y_val))

   model.load_weights(name='weights_best.h5')

   output_project = project.replace('.flimfit', '-n2v.flimfit')
   if os.path.exists(output_project):
      shutil.rmtree(output_project)
   shutil.copytree(project, output_project)

   output_file = os.path.join(output_project, 'fit_results.hdf5')

   X_pred = np.zeros(X.shape)
   for i in range(X.shape[0]):
      X_pred[i, ...] = denormalize(
          model.predict(X[i], axes='YXC', normalizer=None), mean, std)

   X_pred[mask] = np.nan

   insert_results(output_file, X_pred, groups)
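
A minimal (hypothetical) invocation; the project path is illustrative. The denoised results land in a '-n2v.flimfit' copy of the project:

n2v_flim('/data/experiment-01.flimfit', n2v_num_pix=32)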