def Train(self):
    """Prepare the mask folders and train the U-Net and/or StarDist 3D model.

    Steps, each controlled by attributes on ``self``:

    1. Create ``BinaryMask/`` and ``RealMask/`` below ``self.BaseDir``.
    2. If ``RealMask/`` holds no ``.tif`` files, label every binary mask and
       write the result there as uint16. If ``BinaryMask/`` holds no ``.tif``
       files, threshold every instance mask (``> 0``) and write it there.
    3. ``self.GenerateNPZ``: cut training patches from ``Raw/`` and
       ``BinaryMask/`` into ``BaseDir + NPZfilename + '.npz'``.
    4. ``self.TrainUNET``: train a CARE model named ``'UNET' + model_name``
       on that ``.npz`` file.
    5. ``self.TrainSTAR``: train a StarDist3D model named ``model_name``,
       either on images read up front or, with ``self.CroppedLoad``, on
       ``self.DataSequencer`` objects.

    Before training, weights are restored from ``copy_model_dir`` (only when
    the model's own folder lacks the matching file) and then from the
    model's own folder; files are tried in the order now, last, best, so the
    last one found is the one that stays loaded.

    Raises:
        ValueError: if ``self.TrainSTAR`` is set and ``self.backbone`` is
            neither ``'resnet'`` nor ``'unet'``.
        AssertionError: if StarDist training without ``CroppedLoad`` finds
            fewer than two raw images.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'
    weight_files = ('weights_now.h5', 'weights_last.h5', 'weights_best.h5')

    def list_tifs(subfolder):
        # Sorted .tif paths of one sub-folder (name ends with '/') of BaseDir.
        return sorted(glob.glob(self.BaseDir + '/' + subfolder + '*.tif'))

    def restore_weights(model, model_folder, copy_folder, copy_files):
        # Warm start from the copy model first, then let the model's own
        # checkpoints override it.
        if copy_folder is not None:
            for weights in copy_files:
                if (os.path.exists(copy_folder + weights)
                        and not os.path.exists(model_folder + weights)):
                    print('Loading copy model')
                    model.load_weights(copy_folder + weights)
        for weights in weight_files:
            if os.path.exists(model_folder + weights):
                print('Loading checkpoint model')
                model.load_weights(model_folder + weights)

    Raw = list_tifs('Raw/')
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = list_tifs(RealName)
    ValRaw = list_tifs('ValRaw/')
    ValRealMask = list_tifs('ValRealMask/')

    print('Instance segmentation masks:', len(RealMask))
    if not RealMask:
        print('Making labels')
        for fname in list_tifs(BinaryName):
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            labelled = label(image)
            imwrite(self.BaseDir + '/' + RealName + Name + '.tif',
                    labelled.astype('uint16'))
        # Pick up the files written above; otherwise StarDist training below
        # would pair the raw images with an empty mask list.
        RealMask = list_tifs(RealName)

    Mask = list_tifs(BinaryName)
    print('Semantic segmentation masks:', len(Mask))
    if not Mask:
        print('Generating Binary images')
        for fname in list_tifs(RealName):
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            binary = image > 0
            imwrite(self.BaseDir + '/' + BinaryName + Name + '.tif',
                    binary.astype('uint16'))

    if self.GenerateNPZ:
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.BaseDir + self.NPZfilename + '.npz',
        )

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        load_path = self.BaseDir + self.NPZfilename + '.npz'
        (X, Y), (X_val, Y_val), axes = load_training_data(
            load_path, validation_split=0.1, verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        config = Config(
            axes,
            n_channel_in,
            n_channel_out,
            unet_n_depth=self.depth,
            train_epochs=self.epochs,
            train_batch_size=self.batch_size,
            unet_n_first=self.startfilter,
            train_loss='mse',
            unet_kern_size=self.kern_size,
            train_learning_rate=self.learning_rate,
            train_reduce_lr={'patience': 5, 'factor': 0.5},
        )
        print(config)

        model = CARE(config, name='UNET' + self.model_name,
                     basedir=self.model_dir)

        unet_copy_folder = None
        if self.copy_model_dir is not None:
            unet_copy_folder = (self.copy_model_dir + 'UNET'
                                + self.copy_model_name + '/')
        # Only 'weights_now.h5' is ever taken from the copy model here.
        restore_weights(
            model,
            self.model_dir + 'UNET' + self.model_name + '/',
            unet_copy_folder,
            ('weights_now.h5',),
        )

        history = model.train(X, Y, validation_data=(X_val, Y_val))
        print(sorted(history.history.keys()))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        # Fail before any data is loaded instead of hitting an unbound
        # `conf` after the (potentially slow) loading step.
        if self.backbone not in ('resnet', 'unet'):
            raise ValueError(
                "unsupported backbone %r, expected 'resnet' or 'unet'"
                % (self.backbone,))

        self.axis_norm = (0, 1, 2)

        if self.CroppedLoad:
            self.X_trn = self.DataSequencer(
                Raw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_trn = self.DataSequencer(
                RealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.X_val = self.DataSequencer(
                ValRaw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_val = self.DataSequencer(
                ValRealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.train_sample_cache = False
        else:
            assert len(Raw) > 1, "not enough training data"
            print(len(Raw))
            # Fixed seed so the train/validation split is reproducible.
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))

            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))
            self.Y = [
                label(DownsampleData(y, self.DownsampleFactor))
                for y in tqdm(Y_train)
            ]
            self.X = [
                normalize(DownsampleData(x, self.DownsampleFactor),
                          1, 99.8, axis=self.axis_norm)
                for x in tqdm(X_train)
            ]

            # Hold out 15 % of the images (at least one) for validation.
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]
            self.X_val = [self.X[i] for i in ind_val]
            self.Y_val = [self.Y[i] for i in ind_val]
            self.X_trn = [self.X[i] for i in ind_train]
            self.Y_trn = [self.Y[i] for i in ind_train]

            print('number of images: %3d' % len(self.X))
            print('- training: %3d' % len(self.X_trn))
            print('- validation: %3d' % len(self.X_val))

        print(Config3D.__doc__)

        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
        kernel = (self.kern_size, self.kern_size, self.kern_size)
        shared = dict(
            rays=rays,
            anisotropy=anisotropy,
            backbone=self.backbone,
            train_epochs=self.epochs,
            train_learning_rate=self.learning_rate,
            train_checkpoint=self.model_dir + self.model_name + '.h5',
            # Z, Y, X order, matching the 'ZYX' patches created above.
            train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            train_batch_size=self.batch_size,
            train_dist_loss='mse',
            grid=(1, 1, 1),
            use_gpu=self.use_gpu,
            n_channel_in=1,
        )
        if self.backbone == 'resnet':
            conf = Config3D(
                resnet_n_blocks=self.depth,
                resnet_kernel_size=kernel,
                resnet_n_filter_base=self.startfilter,
                **shared)
        else:
            conf = Config3D(
                unet_n_depth=self.depth,
                unet_kernel_size=kernel,
                unet_n_filter_base=self.startfilter,
                train_sample_cache=False,
                **shared)
        print(conf)

        Starmodel = StarDist3D(conf, name=self.model_name,
                               basedir=self.model_dir)
        star_folder = self.model_dir + self.model_name + '/'
        print(Starmodel._axes_tile_overlap('ZYX'),
              os.path.exists(star_folder + 'weights_now.h5'))

        star_copy_folder = None
        if self.copy_model_dir is not None:
            star_copy_folder = (self.copy_model_dir
                                + self.copy_model_name + '/')
        restore_weights(Starmodel, star_folder, star_copy_folder,
                        weight_files)

        historyStar = Starmodel.train(
            self.X_trn,
            self.Y_trn,
            validation_data=(self.X_val, self.Y_val),
            epochs=self.epochs)
        print(sorted(historyStar.history.keys()))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae',
            'dist_relevant_mse', 'val_dist_relevant_mse'
        ])
def main():
    """Command-line entry point: restore every matching image with a CARE model.

    Parses the arguments via ``parse_args()``, loads the CARE model named
    ``--model-name`` from ``--model-basedir``, runs ``model.predict`` on each
    file in ``--input-dir`` matching the input pattern, and (unless
    ``--dry-run``) writes the result below ``--output-dir``.

    Exits with status 0 when run interactively, when called without any
    arguments (after printing help), or when no input files match; exits
    with status 1 when a required argument is missing. An input or output
    name without a ``.tif``/``.tiff`` suffix triggers ``_raise`` with a
    ``ValueError``.
    """
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s"
                  % (parser.prog, r), file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function (a no-op in quiet mode)
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if len(file_list) == 0:
        log("No files to process in '%s' with pattern '%s'."
            % (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay the heavy imports until all required arguments are known to be
    # present, so usage errors are reported quickly
    from tifffile import imread
    try:
        # Prefer `imwrite`; `imsave` is only a legacy alias that may be
        # missing from the installed tifffile release.
        from tifffile import imwrite as imsave
    except ImportError:
        from tifffile import imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer

    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name,
                 basedir=args.model_basedir)
    if args.model_weights is not None:
        print("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin,
                                      pmax=args.norm_pmax,
                                      do_after=args.norm_undo)

    # a single tile count is passed on as a scalar rather than a 1-list
    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files; the outer progress bar is disabled when tiling is
    # requested (more than one tile in total)
    show_no_bar = args.quiet or (n_tiles is not None and np.prod(n_tiles) > 1)
    for file_in in tqdm(file_list, disable=show_no_bar):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=(Path(args.model_weights).stem
                           if args.model_weights is not None else None))

        # checks
        (file_in.suffix.lower() in ('.tif', '.tiff')
         and file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(
             ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img, axes=args.input_axes,
                                 normalizer=normalizer, n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            assert restored.ndim == img.ndim + 1
            assert 'C' not in axes_out
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s'
            % (n_processed, 'files' if n_processed > 1 else 'file'))
        # underline as wide as the headline printed just above
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(
                1 + i, str(file_in)))
            log(('{:>%d} out: {:>%d}' % (len_processed, len_file)).format(
                '', str(file_out)))
def n2v_flim(project, n2v_num_pix=32):
    """Train a Noise2Void model on a project's fit results and store the
    denoised results in a sibling ``-n2v`` copy of the project.

    The results are read from ``<project>/fit_results.hdf5``, normalised by
    their global mean and standard deviation, used to train a CARE model
    configured with ``train_scheme='Noise2Void'`` (saved as ``n2v_model``
    inside ``project``), and the predictions made with the best weights are
    inserted into ``fit_results.hdf5`` of the output project. An existing
    output project is deleted and replaced by a fresh copy of ``project``.

    Args:
        project: Path string of the project folder; it must contain
            ``.flimfit``, which is replaced by ``-n2v.flimfit`` to form the
            output folder name.
        n2v_num_pix: Passed to the config as ``n2v_num_pix`` and as both
            sides of ``n2v_patch_shape``; also scales the ``num_pix`` value
            handed to ``manipulate_val_data``.

    Raises:
        ValueError: if ``project`` does not contain ``.flimfit``. Without
            this check the output path would equal the input path and the
            input project itself would be deleted.
    """
    # Resolve and validate the destination before doing any work.
    output_project = project.replace('.flimfit', '-n2v.flimfit')
    if output_project == project:
        raise ValueError(
            "project path %r does not contain '.flimfit'; refusing to "
            "overwrite the input project" % (project,))

    results_file = os.path.join(project, 'fit_results.hdf5')
    X, groups, mask = extract_results(results_file)

    data_shape = np.shape(X)
    print(data_shape)

    mean, std = np.mean(X), np.std(X)
    X = normalize(X, mean, std)

    XA = X  # augment_data(X)
    # Copy the validation slice: manipulate_val_data() is called only for its
    # effect on its arguments (its return value is unused), and a plain slice
    # would be a view sharing memory with the training/prediction data in X.
    X_val = X[0:10, ...].copy()

    # We concatenate an extra channel filled with zeros. It will be
    # internally used for the masking.
    Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
    Y_val = np.concatenate((X_val.copy(), np.zeros(X_val.shape)), axis=-1)

    n_x = X.shape[1]
    n_chan = X.shape[-1]

    manipulate_val_data(X_val, Y_val,
                        num_pix=n_x * n_x * 2 / n2v_num_pix,
                        shape=(n_x, n_x))

    # You can increase "train_steps_per_epoch" to get even better results
    # at the price of longer computation.
    config = Config(
        'SYXC',
        n_channel_in=n_chan,
        n_channel_out=n_chan,
        unet_kern_size=5,
        unet_n_depth=2,
        train_steps_per_epoch=200,
        train_loss='mae',
        train_epochs=35,
        batch_norm=False,
        train_scheme='Noise2Void',
        train_batch_size=128,
        n2v_num_pix=n2v_num_pix,
        n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
        n2v_manipulator='uniform_withCP',
        n2v_neighborhood_radius='5',
    )

    model = CARE(config, 'n2v_model', basedir=project)
    model.train(XA, Y, validation_data=(X_val, Y_val))
    model.load_weights(name='weights_best.h5')

    # Start the output project from a fresh copy of the input project.
    if os.path.exists(output_project):
        shutil.rmtree(output_project)
    shutil.copytree(project, output_project)
    output_file = os.path.join(output_project, 'fit_results.hdf5')

    X_pred = np.zeros(X.shape)
    for i in range(X.shape[0]):
        prediction = model.predict(X[i], axes='YXC', normalizer=None)
        X_pred[i, ...] = denormalize(prediction, mean, std)

    # `np.nan` rather than the `np.NaN` alias, which NumPy 2.0 removed.
    X_pred[mask] = np.nan

    insert_results(output_file, X_pred, groups)