def test_load_and_export_TF():
    """Load the model located by ``path_model3d()`` and export it to TF twice.

    The export is run with ``single_output=True``, once without and once
    with grid upsampling.
    """
    location = path_model3d()
    model = StarDist3D(None, name=location.name, basedir=str(location.parent))
    for upsample in (False, True):
        model.export_TF(single_output=True, upsample_grid=upsample)
def _model3d():
    """Return a StarDist3D loaded from the location given by ``path_model3d()``."""
    # Imports are kept local, as in the original helper.
    from utils import path_model3d
    from stardist.models import StarDist3D

    location = path_model3d()
    model_name = location.name
    model_basedir = str(location.parent)
    return StarDist3D(None, name=model_name, basedir=model_basedir)
def test_model(tmpdir, n_rays, grid, n_channel, backbone):
    """Train a small StarDist3D on noisy synthetic images, then predict.

    Prediction is run once on the whole image and once with tiling.
    """
    base = circle_image(shape=(64, 80, 96))
    stack = np.repeat(base[np.newaxis], 3, axis=0)
    if n_channel is None:
        n_channel = 1
    else:
        # Append a trailing channel axis with n_channel copies.
        stack = np.repeat(stack[..., np.newaxis], n_channel, axis=-1)

    X = stack + .6 * np.random.uniform(0, 1, stack.shape)
    label_source = stack if stack.ndim == 4 else stack[..., 0]
    Y = label_source.astype(int)

    settings = dict(
        backbone=backbone,
        rays=n_rays,
        grid=grid,
        n_channel_in=n_channel,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=2,
        train_batch_size=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 64),
    )
    conf = Config3D(**settings)

    model = StarDist3D(conf, name='stardist', basedir=str(tmpdir))
    model.train(X, Y, validation_data=(X[:2], Y[:2]))

    sample = X[0]
    tiling = (1, 2, 3) if sample.ndim == 3 else (1, 2, 3, 1)
    ref = model.predict(sample)
    res = model.predict(sample, n_tiles=tiling)
def read_model(model_name: str) -> StarDist3D:
    """Load the StarDist3D model ``model_name`` from the ``models`` directory.

    The directory is resolved as ``models`` under the second parent of this
    file. Loading happens inside a ``HidePrint`` context.
    """
    models_dir = pl.Path(__file__).parents[1].joinpath('models')
    with HidePrint():
        loaded = StarDist3D(None, name=model_name, basedir=str(models_dir))
        return loaded
def print_receptive_fields():
    """Print the computed receptive field for several 3D configurations.

    Covers the "unet" backbone at depths 1-3 and the "resnet" backbone, each
    with grids (1, 1, 1) and (2, 2, 2).
    """
    grids = ((1, 1, 1), (2, 2, 2))

    backbone = "unet"
    for n_depth in (1, 2, 3):
        for grid in grids:
            unet_conf = Config3D(backbone=backbone, grid=grid, unet_n_depth=n_depth)
            fov = StarDist3D(unet_conf, None, None)._compute_receptive_field()
            print(
                f"backbone: {backbone} \t n_depth: {n_depth} \t grid {grid} -> fov: {fov}"
            )

    backbone = "resnet"
    for grid in grids:
        resnet_conf = Config3D(backbone=backbone, grid=grid)
        fov = StarDist3D(resnet_conf, None, None)._compute_receptive_field()
        print(f"backbone: {backbone} \t grid {grid} -> fov: {fov}")
def test_model(tmpdir, n_rays, grid, n_channel, backbone):
    """Train and predict with a small StarDist3D, then check the foreground warning.

    The second part trains with ``train_foreground_only`` on constant label
    images and expects a ``UserWarning``.
    """
    base = circle_image(shape=(64, 80, 96))
    stack = np.repeat(base[np.newaxis], 3, axis=0)
    if n_channel is None:
        n_channel = 1
    else:
        stack = np.repeat(stack[..., np.newaxis], n_channel, axis=-1)

    X = stack + .6 * np.random.uniform(0, 1, stack.shape)
    label_source = stack if stack.ndim == 4 else stack[..., 0]
    Y = label_source.astype(int)

    conf = Config3D(
        backbone=backbone,
        rays=n_rays,
        grid=grid,
        n_channel_in=n_channel,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=2,
        train_batch_size=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 64),
    )

    model = StarDist3D(conf, name='stardist', basedir=str(tmpdir))
    model.train(X, Y, validation_data=(X[:2], Y[:2]))

    sample = X[0]
    tiling = (1, 2, 3) if sample.ndim == 3 else (1, 2, 3, 1)
    ref = model.predict(sample)
    res = model.predict(sample, n_tiles=tiling)
    # Comparison of ref and res is disabled in the original:
    # assert all(np.allclose(u,v) for u,v in zip(ref,res))

    # Request foreground-only patches when the labels are constant images
    # (all zeros / all ones); this must trigger a warning.
    conf.train_foreground_only = 1
    conf.train_steps_per_epoch = 1
    constant_labels = [np.zeros_like(Y[0]), np.ones_like(Y[1])]
    warn_model = StarDist3D(conf, name='stardist', basedir=None)
    with pytest.warns(UserWarning):
        warn_model.train(X[:2], constant_labels,
                         validation_data=(X[-1:], Y[-1:]))
def test_load_and_export_TF():
    """Load the model from ``path_model3d()`` and export it to TF.

    Asserts that at least one grid entry is > 1, then exports with
    ``single_output=True``, without and with grid upsampling.
    """
    location = path_model3d()
    model = StarDist3D(None, name=location.name, basedir=str(location.parent))
    grid = model.config.grid
    assert any(g > 1 for g in grid)

    # The single_output=False exports are disabled in the original:
    # model.export_TF(single_output=False, upsample_grid=False)
    # model.export_TF(single_output=False, upsample_grid=True)
    for upsample in (False, True):
        model.export_TF(single_output=True, upsample_grid=upsample)
def test_load_and_predict_with_overlap():
    """Predict with tiling, then check that ``overlap_label=-3`` shows up in the labels.

    Returns the loaded model and the label image.
    """
    location = path_model3d()
    model = StarDist3D(None, name=location.name, basedir=str(location.parent))

    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)

    prob, dist = model.predict(x, n_tiles=(1, 2, 2))
    # prob matches the first three axes of dist; the last axis of dist
    # holds one entry per ray.
    assert prob.shape == dist.shape[:3]
    assert model.config.n_rays == dist.shape[-1]

    overlap_value = -3
    labels, _ = model.predict_instances(x, nms_thresh=.5,
                                        overlap_label=overlap_value)
    assert np.min(labels) == overlap_value
    return model, labels
def test_load_and_predict():
    """Predict on the real 3D test image and compare against pinned matching counts.

    Returns the loaded model and the predicted label image.
    """
    location = path_model3d()
    model = StarDist3D(None, name=location.name, basedir=str(location.parent))

    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)

    prob, dist = model.predict(x, n_tiles=(1, 2, 2))
    assert prob.shape == dist.shape[:3]
    assert model.config.n_rays == dist.shape[-1]

    labels, _ = model.predict_instances(x)
    assert labels.shape == img.shape[:3]

    stats = matching(mask, labels, thresh=0.5)
    expected_fp_tp_fn = (0, 30, 21)
    assert (stats.fp, stats.tp, stats.fn) == expected_fp_tp_fn
    return model, labels
def test_predict_dense_sparse():
    """Check that dense and sparse instance prediction give matching results."""
    location = path_model3d()
    model = StarDist3D(None, name=location.name, basedir=str(location.parent))

    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)

    tiling = (1, 2, 2)
    labels1, res1 = model.predict_instances(x, n_tiles=tiling, sparse=False)
    labels2, res2 = model.predict_instances(x, n_tiles=tiling, sparse=True)

    assert np.allclose(labels1, labels2)
    # Compare every key that appears in either result dict.
    all_keys = set(res1.keys()).union(set(res2.keys()))
    assert all(np.allclose(res1[key], res2[key]) for key in all_keys)

    # NOTE(review): the sparse labels are returned twice; (labels1, labels2)
    # may have been intended. Kept as-is to preserve behavior.
    return labels2, labels2
def test_foreground_warning():
    """Training with foreground-only patches on constant labels must warn.

    The label images are all ones, and ``train_foreground_only=1`` is set;
    a ``UserWarning`` is expected.
    """
    shape = (2, 32, 48, 16)
    X = np.ones(shape, np.float32)
    Y = np.ones(shape, np.uint16)

    conf = Config3D(
        n_rays=32,
        train_patch_size=(16, 32, 16),
        train_foreground_only=1,
        train_steps_per_epoch=1,
        train_epochs=1,
        train_batch_size=2,
    )
    model = StarDist3D(conf, None, None)
    with pytest.warns(UserWarning):
        model.train(X, Y, validation_data=(X[-1:], Y[-1:]))
def test_optimize_thresholds():
    """Threshold optimization must agree with and without implicit tiling.

    Returns the loaded model, whose config is modified by the test.
    """
    location = path_model3d()
    model = StarDist3D(None, name=location.name, basedir=str(location.parent))

    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)

    def _run_optimization(m):
        # Same search settings for both runs; nothing is written to JSON.
        search = dict(
            nms_threshs=[.3, .5],
            iou_threshs=[.3, .5],
            optimize_kwargs=dict(tol=1e-1),
            save_to_json=False,
        )
        return m.optimize_thresholds([x], [mask], **search)

    t1 = _run_optimization(model)

    # Enforce implicit tiling: batch size 1 and a patch one voxel smaller
    # than the image along every axis.
    model.config.train_batch_size = 1
    model.config.train_patch_size = tuple(s - 1 for s in x.shape)
    t2 = _run_optimization(model)

    assert all(np.allclose(t1[key], t2[key]) for key in t1.keys())
    return model
def test_model(n_rays, grid):
    """Train a StarDist3D for one short epoch on synthetic data in a temp directory."""
    base = circle_image(shape=(64, 80, 96))
    stack = np.repeat(base[np.newaxis], 10, axis=0)
    X = stack + .6 * np.random.uniform(0, 1, stack.shape)
    Y = stack.astype(np.uint16)

    conf = Config3D(
        rays=n_rays,
        grid=grid,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 64),
        n_channel_in=1,
    )

    # The model directory is removed when the context exits.
    with tempfile.TemporaryDirectory() as workdir:
        model = StarDist3D(conf, name='stardist', basedir=workdir)
        model.train(X, Y, validation_data=(X[:2], Y[:2]))
def train_model(x_train, y_train, x_val, y_val, save_path, patch_size, anisotropy, n_rays=96):
    """Train a StarDist3D model and return its optimized thresholds.

    Parameters
    ----------
    x_train, y_train : training images and label images, passed to ``model.train``.
    x_val, y_val : validation images and label images; also used for
        threshold optimization.
    save_path : path whose last component is the model name and whose
        directory part is the model base directory.
    patch_size : forwarded as ``train_patch_size``.
    anisotropy : per-axis anisotropy; used for the rays and to pick the grid.
    n_rays : number of rays for ``Rays_GoldenSpiral`` (default 96).

    Returns
    -------
    The value returned by ``model.optimize_thresholds(x_val, y_val)``.
    """
    rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

    # Copied from the stardist training notebook. The expression
    # short-circuits, so use_gpu is always False and gputools_available()
    # is never called.
    use_gpu = False and gputools_available()

    # Predict on a subsampled grid for efficiency: grid 1 along axes with
    # anisotropy above 1.5, grid 2 elsewhere.
    grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)
    config = Config3D(rays=rays, grid=grid, use_gpu=use_gpu, n_channel_in=1,
                      train_patch_size=patch_size, anisotropy=anisotropy)

    if use_gpu:
        print("Using a GPU for training")
        # Limit GPU memory.
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(0.8)
    else:
        print("GPU not found, using the CPU for training")

    save_root, save_name = os.path.split(save_path)
    # os.makedirs('') raises FileNotFoundError, so skip directory creation
    # when save_path has no directory component.
    if save_root:
        os.makedirs(save_root, exist_ok=True)

    model = StarDist3D(config, name=save_name, basedir=save_root)
    # `augmenter` is resolved from the enclosing module scope.
    model.train(x_train, y_train, validation_data=(x_val, y_val), augmenter=augmenter)
    optimal_parameters = model.optimize_thresholds(x_val, y_val)
    return optimal_parameters
def test_model(tmpdir, n_rays, grid, n_channel, backbone, workers, use_sequence):
    """Train a small StarDist3D, optionally from ``NumpySequence`` inputs, and predict.

    Returns the trained model.
    """
    base = circle_image(shape=(64, 80, 96))
    stack = np.repeat(base[np.newaxis], 3, axis=0)
    if n_channel is None:
        n_channel = 1
    else:
        stack = np.repeat(stack[..., np.newaxis], n_channel, axis=-1)

    X = stack + .6 * np.random.uniform(0, 1, stack.shape)
    label_source = stack if stack.ndim == 4 else stack[..., 0]
    Y = label_source.astype(int)

    if use_sequence:
        X = NumpySequence(X)
        Y = NumpySequence(Y)

    conf = Config3D(
        backbone=backbone,
        rays=n_rays,
        grid=grid,
        n_channel_in=n_channel,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=1,
        train_batch_size=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 32),
        # The sample cache is switched off when sequences are used.
        train_sample_cache=not use_sequence,
    )

    model = StarDist3D(conf, name='stardist', basedir=str(tmpdir))
    model.train(X, Y, validation_data=(X[:2], Y[:2]), workers=workers)

    sample = X[0]
    tiling = (1, 2, 3) if sample.ndim == 3 else (1, 2, 3, 1)
    ref = model.predict(sample)
    res = model.predict(sample, n_tiles=tiling)
    # Deactivated as the order of labels might not be the same:
    # assert all(np.allclose(u,v) for u,v in zip(ref,res))
    return model
def run_prediction(image_files, model_path, root, prediction_folder):
    """Run StarDist3D instance prediction on each file and save the label images.

    Results go to ``<root>/<prediction_folder>`` under the input file's
    base name.
    """
    # Split the model path (without trailing slashes) into base directory and name.
    model_root, model_name = os.path.split(model_path.rstrip('/'))
    model = StarDist3D(None, name=model_name, basedir=model_root)

    res_folder = os.path.join(root, prediction_folder)
    os.makedirs(res_folder, exist_ok=True)

    # Lower and upper percentile and axes used for image normalization.
    # Maybe these should be exposed.
    lower_percentile, upper_percentile = 1, 99.8
    ax_norm = (0, 1, 2)

    for im_file in tqdm(image_files, desc="run stardist prediction"):
        volume = imageio.volread(im_file)
        volume = normalize(volume, lower_percentile, upper_percentile, axis=ax_norm)
        pred, _ = model.predict_instances(volume)
        target = os.path.join(res_folder, os.path.basename(im_file))
        imageio.imsave(target, pred)
def Train(self):
    """Prepare mask folders and train the UNET (CARE) and/or StarDist3D models.

    Steps, each driven by attributes on ``self``:

    1. Ensure ``BinaryMask/`` and ``RealMask/`` exist under ``self.BaseDir``
       and derive whichever set of masks is missing from the other.
    2. If ``self.GenerateNPZ``: create training patches and save them as NPZ.
    3. If ``self.TrainUNET``: train a CARE model from the NPZ file, after
       loading any existing weights.
    4. If ``self.TrainSTAR``: train a StarDist3D model, after loading any
       existing weights.

    NOTE(review): the source was available only with collapsed whitespace;
    block nesting here was reconstructed from statement dependencies and
    should be confirmed against the original file.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'
    Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
    ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
    ValRealMask = sorted(
        glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

    print('Instance segmentation masks:', len(RealMask))
    if len(RealMask) == 0:
        # No instance masks on disk: build them by labelling each binary mask.
        print('Making labels')
        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        for fname in Mask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = label(image)
            imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                    Binaryimage.astype('uint16'))
        # NOTE(review): RealMask is not re-globbed after writing, so the
        # list used further down stays empty in this case; confirm intended.

    Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
    print('Semantic segmentation masks:', len(Mask))
    if len(Mask) == 0:
        # No binary masks on disk: threshold each instance mask at > 0.
        print('Generating Binary images')
        # NOTE(review): this pattern is '*tif' (no dot), unlike the other globs.
        RealfilesMask = sorted(
            glob.glob(self.BaseDir + '/' + RealName + '*tif'))
        for fname in RealfilesMask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = image > 0
            imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                    Binaryimage.astype('uint16'))

    if self.GenerateNPZ:
        # Pair Raw/ with BinaryMask/ and save patches to an NPZ file.
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        X, Y, XY_axes = create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.BaseDir + self.NPZfilename + '.npz',
        )

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        load_path = self.BaseDir + self.NPZfilename + '.npz'
        (X, Y), (X_val, Y_val), axes = load_training_data(load_path,
                                                         validation_split=0.1,
                                                         verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_n_depth=self.depth,
                        train_epochs=self.epochs,
                        train_batch_size=self.batch_size,
                        unet_n_first=self.startfilter,
                        train_loss='mse',
                        unet_kern_size=self.kern_size,
                        train_learning_rate=self.learning_rate,
                        train_reduce_lr={
                            'patience': 5,
                            'factor': 0.5
                        })
        print(config)
        # The result of vars() is discarded; this statement has no effect here.
        vars(config)
        model = CARE(config,
                     name='UNET' + self.model_name,
                     basedir=self.model_dir)

        # Warm start from the copy model only when the target model has no
        # 'weights_now.h5' of its own yet.
        if self.copy_model_dir is not None:
            if os.path.exists(self.copy_model_dir + 'UNET' +
                              self.copy_model_name + '/' +
                              'weights_now.h5') and os.path.exists(
                                  self.model_dir + 'UNET' + self.model_name +
                                  '/' + 'weights_now.h5') == False:
                print('Loading copy model')
                model.load_weights(self.copy_model_dir + 'UNET' +
                                   self.copy_model_name + '/' +
                                   'weights_now.h5')

        # Checkpoints are loaded in the order now -> last -> best; each load
        # overrides the previous one, so the last existing file wins.
        if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                          'weights_now.h5'):
            print('Loading checkpoint model')
            model.load_weights(self.model_dir + 'UNET' + self.model_name +
                               '/' + 'weights_now.h5')

        if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                          'weights_last.h5'):
            print('Loading checkpoint model')
            model.load_weights(self.model_dir + 'UNET' + self.model_name +
                               '/' + 'weights_last.h5')

        if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                          'weights_best.h5'):
            print('Loading checkpoint model')
            model.load_weights(self.model_dir + 'UNET' + self.model_name +
                               '/' + 'weights_best.h5')

        history = model.train(X, Y, validation_data=(X_val, Y_val))

        print(sorted(list(history.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        self.axis_norm = (0, 1, 2)

        if self.CroppedLoad == False:
            # Load everything into memory and split off ~15% (at least one
            # image) for validation, using a fixed-seed permutation.
            assert len(Raw) > 1, "not enough training data"
            print(len(Raw))
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))

            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))

            self.Y = [
                label(DownsampleData(y, self.DownsampleFactor))
                for y in tqdm(Y_train)
            ]
            self.X = [
                normalize(DownsampleData(x, self.DownsampleFactor),
                          1,
                          99.8,
                          axis=self.axis_norm) for x in tqdm(X_train)
            ]
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]

            self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                      ], [self.Y[i] for i in ind_val]
            self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                      ], [self.Y[i] for i in ind_train]

            print('number of images: %3d' % len(self.X))
            print('- training: %3d' % len(self.X_trn))
            print('- validation: %3d' % len(self.X_val))

        if self.CroppedLoad:
            # Feed data through DataSequencer objects; validation data comes
            # from the separate ValRaw/ and ValRealMask/ folders.
            self.X_trn = self.DataSequencer(Raw,
                                            self.axis_norm,
                                            Normalize=True,
                                            labelMe=False)
            self.Y_trn = self.DataSequencer(RealMask,
                                            self.axis_norm,
                                            Normalize=False,
                                            labelMe=True)
            self.X_val = self.DataSequencer(ValRaw,
                                            self.axis_norm,
                                            Normalize=True,
                                            labelMe=False)
            self.Y_val = self.DataSequencer(ValRealMask,
                                            self.axis_norm,
                                            Normalize=False,
                                            labelMe=True)
            # NOTE(review): this attribute is not read anywhere in this
            # method. The 'unet' config hard-codes train_sample_cache=False
            # and the 'resnet' config does not pass it at all.
            self.train_sample_cache = False

        print(Config3D.__doc__)

        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

        # NOTE(review): conf is assigned only for 'resnet' and 'unet'; any
        # other backbone value reaches print(conf) with conf unbound.
        # NOTE(review): train_patch_size below is (PatchZ, PatchX, PatchY),
        # whereas create_patches above uses (PatchZ, PatchY, PatchX);
        # confirm the intended axis order.
        if self.backbone == 'resnet':
            conf = Config3D(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                resnet_n_blocks=self.depth,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                resnet_kernel_size=(self.kern_size, self.kern_size,
                                    self.kern_size),
                train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                train_batch_size=self.batch_size,
                resnet_n_filter_base=self.startfilter,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1)

        if self.backbone == 'unet':
            conf = Config3D(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                unet_n_depth=self.depth,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                unet_kernel_size=(self.kern_size, self.kern_size,
                                  self.kern_size),
                train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                train_batch_size=self.batch_size,
                unet_n_filter_base=self.startfilter,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1,
                train_sample_cache=False)

        print(conf)
        # The result of vars() is discarded; this statement has no effect here.
        vars(conf)

        Starmodel = StarDist3D(conf,
                               name=self.model_name,
                               basedir=self.model_dir)
        print(
            Starmodel._axes_tile_overlap('ZYX'),
            os.path.exists(self.model_dir + self.model_name + '/' +
                           'weights_now.h5'))

        # Warm start from the copy model: each weights file is taken from the
        # copy only when the target model lacks a file of the same name.
        if self.copy_model_dir is not None:
            if os.path.exists(self.copy_model_dir + self.copy_model_name +
                              '/' + 'weights_now.h5') and os.path.exists(
                                  self.model_dir + self.model_name + '/' +
                                  'weights_now.h5') == False:
                print('Loading copy model')
                Starmodel.load_weights(self.copy_model_dir +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')
            if os.path.exists(self.copy_model_dir + self.copy_model_name +
                              '/' + 'weights_last.h5') and os.path.exists(
                                  self.model_dir + self.model_name + '/' +
                                  'weights_last.h5') == False:
                print('Loading copy model')
                Starmodel.load_weights(self.copy_model_dir +
                                       self.copy_model_name + '/' +
                                       'weights_last.h5')
            if os.path.exists(self.copy_model_dir + self.copy_model_name +
                              '/' + 'weights_best.h5') and os.path.exists(
                                  self.model_dir + self.model_name + '/' +
                                  'weights_best.h5') == False:
                print('Loading copy model')
                Starmodel.load_weights(self.copy_model_dir +
                                       self.copy_model_name + '/' +
                                       'weights_best.h5')

        # The model's own checkpoints are loaded in the order
        # now -> last -> best; the last existing file wins.
        if os.path.exists(self.model_dir + self.model_name + '/' +
                          'weights_now.h5'):
            print('Loading checkpoint model')
            Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                   'weights_now.h5')

        if os.path.exists(self.model_dir + self.model_name + '/' +
                          'weights_last.h5'):
            print('Loading checkpoint model')
            Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                   'weights_last.h5')

        if os.path.exists(self.model_dir + self.model_name + '/' +
                          'weights_best.h5'):
            print('Loading checkpoint model')
            Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                   'weights_best.h5')

        historyStar = Starmodel.train(self.X_trn,
                                      self.Y_trn,
                                      validation_data=(self.X_val,
                                                       self.Y_val),
                                      epochs=self.epochs)

        print(sorted(list(historyStar.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae',
            'dist_relevant_mse', 'val_dist_relevant_mse'
        ])
def run_StarDist(
        inputImagePath, z_count, t_count, model_selection,
        probThreshold, nmsThreshold, normalizationLow, normalizationHigh,
        output_type, resultPath):
    """Segment a 2D/3D (optionally time-lapse) TIFF with StarDist and save the result.

    Parameters
    ----------
    inputImagePath : path of the input TIFF, read with ``tifffile.imread``.
    z_count : number of Z slices; > 1 selects the 3D model, 1 the 2D model.
    t_count : number of time points; each one is predicted separately.
    model_selection : key into the built-in model table (0-3). An unknown
        key logs a warning and falls back to '2D_demo'.
    probThreshold, nmsThreshold : clipped to [0, 1]; when both are 0 the
        model's own default thresholds are used.
    normalizationLow, normalizationHigh : percentiles clipped to [0, 100];
        when low >= high, (2, 99.9) is used instead.
    output_type : 1 converts the label image into a binary mask.
    resultPath : path of the output TIFF.

    Raises
    ------
    ValueError
        If ``z_count`` is smaller than 1.
    """
    print('------------------------------------------')
    print(' StarDist Virtual Environment')
    print('------------------------------------------')
    print(f' inputImagePath = {inputImagePath}')
    print(f' z_count = {z_count}')
    print(f' t_count = {t_count}')
    print(f' model_selection = {model_selection}')
    print(f' probThreshold = {probThreshold}')
    print(f' nmsThreshold = {nmsThreshold}')
    print(f' normalizationLow = {normalizationLow}')
    print(f'normalizationHigh = {normalizationHigh}')
    print(f' outputType = {output_type}')
    print(f' resultPath = {resultPath}')

    # Limit GPU memory usage
    limit_gpu_memory(fraction=None, allow_growth=True)

    # Get the path of the folder that contains this python script
    script_folder = pathlib.Path(__file__).resolve().parent
    logger.info(f'Script Folder = {script_folder}')

    # Get the model selections
    model_dict = {0: '2D_demo',
                  1: '2D_fluor_nuc',
                  2: '2D_dsb2018',
                  3: '3D_demo'}
    if model_selection not in model_dict:
        logger.warning('Selection is not available, use 2D_demo instead')
    # Fall back to the model announced in the warning instead of raising
    # KeyError on an unknown selection.
    model_name = model_dict.get(model_selection, '2D_demo')

    # Load StarDist model assuming the 3D_demo and 2D_demo folders
    # are both in `script_folder`
    if z_count > 1:
        # Input image is 3D or 3D+T: the 3D_demo model is always used here,
        # regardless of model_selection.
        logger.warning('Input is a 3D/3D+T image, use 3D_demo model')
        model = StarDist3D(None, name='3D_demo', basedir=script_folder)
        # Set 3D block size
        tile_shape = (50, 256, 256)
        # Check if input is a time-lapse image
        axes = 'YXZT' if t_count > 1 else 'YXZ'
    elif z_count == 1:
        # Use 2D model for 2D image
        model = StarDist2D(None, name=model_name, basedir=script_folder)
        # Set 2D tile size
        tile_shape = (512, 512)
        # Check if input is a time-lapse image
        axes = 'TYX' if t_count > 1 else 'YX'
    else:
        raise ValueError('Z count must be positive')

    # Load input image
    image = tifffile.imread(inputImagePath)
    dtype = image.dtype

    # Current limitation: input and output should have the same depth
    if dtype == np.uint8:
        logger.warning('Label image will be saved in 8bit')

    # Not a time-lapse: add a leading axis so the loop below can index by t
    if t_count == 1:
        image = image[np.newaxis]

    # Create output labeled image
    labels = np.empty_like(image, dtype=dtype)
    # One more tile than the number of whole tile_shape blocks per axis
    n_tiles = [i // t + 1 for t, i in zip(tile_shape, image[0].shape)]

    # Get thresholds
    prob_thresh = np.clip(probThreshold, 0.0, 1.0)
    nms_thresh = np.clip(nmsThreshold, 0.0, 1.0)
    # Use default thresholds optimized for the StarDist model when both
    # thresholds are set as 0
    if prob_thresh == 0.0 and nms_thresh == 0.0:
        logger.warning(
            'Use default thresholds of the StarDist model when both '
            'thresholds are set as 0.')
        prob_thresh = nms_thresh = None
    logger.info(f'probThreshold = {prob_thresh}, nmsThreshold = {nms_thresh}')

    # Get Normalization Percentile
    p_min = np.clip(normalizationLow, 0.0, 100.0)
    p_max = np.clip(normalizationHigh, 0.0, 100.0)
    # Use default normalization for the StarDist model when p_min >= p_max
    if p_min >= p_max:
        logger.warning(
            'Use default normalization of the StarDist model '
            'when p_min >= p_max.')
        p_min, p_max = 2, 99.9
    logger.info(f'normalizationLow = {p_min}, normalizationHigh = {p_max}')

    # Apply StarDist model
    for t in range(t_count):
        labels[t] = model.predict_instances(
            normalize(image[t], p_min=p_min, p_max=p_max),
            prob_thresh=prob_thresh,
            nms_thresh=nms_thresh,
            n_tiles=n_tiles,
            show_tile_progress=False)[0].astype(dtype)

        # Convert labeled mask to binary mask
        if output_type == 1:
            # Separate neighboring masks with a gap before binarizing
            # (helpers are defined elsewhere in this module).
            if z_count > 1:
                addOnePixelGap_3D(labels[t])
            else:
                addOnePixelGap_2D(labels[t])

            # Foreground value: 65535 for 16-bit output, 255 otherwise
            val = 255
            if labels[t].dtype.type == np.uint16:
                val = 65535
            labels[t] = (labels[t] > 0) * val

    # Not a time-lapse: drop the leading axis added above
    if t_count == 1:
        labels = labels[0]

    # Save the labeled image
    tifffile.imwrite(resultPath,
                     labels,
                     photometric='minisblack',
                     metadata={'axes': axes})
def __len__(self):
    # Length is the stored element count `n`. The enclosing class is
    # outside the visible part of the file.
    return self.n


# Script entry point: train a small StarDist3D on LargeSequence data and
# call print_memory() at the end of every epoch.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    # -n sets the number of training epochs (train_epochs below).
    parser.add_argument("-n", type=int, default=10)
    # --size is passed to LargeSequence and caps the patch side at 96.
    parser.add_argument("-s", "--size", type=int, default=256)
    # --steps sets train_steps_per_epoch.
    parser.add_argument("--steps", type=int, default=20)
    # --nocache turns train_sample_cache off.
    parser.add_argument("--nocache", action='store_true')
    args = parser.parse_args()

    # NOTE(review): the sequence length is hard-coded to 20.
    X, Y = LargeSequence(20, size=args.size), LargeSequence(20, size=args.size)

    conf = Config3D(n_rays=32,
                    backbone="unet",
                    unet_n_depth=1,
                    train_epochs=args.n,
                    train_steps_per_epoch=args.steps,
                    train_batch_size=1,
                    train_patch_size=(min(96, args.size), ) * 3,
                    train_sample_cache=not args.nocache)

    # No name and no basedir are given to the model.
    model = StarDist3D(conf, None, None)

    # prepare_for_training() is called explicitly so that model.callbacks
    # can be extended before train().
    model.prepare_for_training()
    model.callbacks.append(
        LambdaCallback(on_epoch_end=lambda a, b: print_memory()))

    # Validate on the first element only, with a leading axis added.
    model.train(X, Y, validation_data=(X[0][np.newaxis], Y[0][np.newaxis]))
def run(self, workspace):
    """Segment the input image with StarDist and store the resulting objects.

    Validates the model/image combination, loads either a pretrained 2D
    model or a custom 2D/3D model, optionally tiles the prediction, and
    optionally stores a resized probability image.

    Raises ``ValueError`` for a greyscale/colour mismatch and
    ``NotImplementedError`` for volumetric input with an inbuilt model.
    """
    x = workspace.image_set.get_image(self.x_name.value)
    dimensions = x.dimensions
    x_data = x.pixel_data

    # Validate some settings
    if self.model.value in (GREY_1, GREY_2) and x.multichannel:
        raise ValueError(
            "Color images are not supported by this model. Please provide greyscale images."
        )
    if self.model.value == COLOR_1 and not x.multichannel:
        raise ValueError(
            "Greyscale images are not supported by this model. Please provide a color overlay."
        )

    if self.model.value == MODEL_CUSTOM:
        # Custom model: the directory setting is split into basedir and name.
        model_directory, model_name = os.path.split(
            self.model_directory.get_absolute_path())
        if x.volumetric:
            from stardist.models import StarDist3D

            model = StarDist3D(config=None,
                               basedir=model_directory,
                               name=model_name)
        else:
            model = StarDist2D(config=None,
                               basedir=model_directory,
                               name=model_name)
    else:
        if x.volumetric:
            raise NotImplementedError(
                "StarDist's inbuilt models do not currently support 3D images"
            )
        model = StarDist2D.from_pretrained(self.model.value)

    tiles = None
    if self.tile_image.value:
        # A volumetric image gets a leading tile count of 1.
        tiles = [1] if x.volumetric else []
        tiles.extend([self.n_tiles_x.value, self.n_tiles_y.value])
        # Handle colour channels: pad with 1s up to the image's ndim.
        padding = max(0, x.pixel_data.ndim - len(tiles))
        tiles.extend([1] * padding)
        print(x.pixel_data.shape, x.pixel_data.ndim, tiles)

    if self.save_probabilities.value:
        data, probs = model.predict_instances(normalize(x.pixel_data),
                                              return_predict=True,
                                              sparse=False,
                                              n_tiles=tiles)
        y_data = data[0]

        # Scores aren't at the same resolution as the input image.
        # We need to slightly resize to match the original image.
        size_corrected = resize(probs[0], y_data.shape)
        prob_image = Image(
            size_corrected,
            parent_image=x.parent_image,
            convert=False,
            dimensions=len(size_corrected.shape),
        )
        workspace.image_set.add(self.probabilities_name.value, prob_image)

        if self.show_window:
            workspace.display_data.probabilities = size_corrected
    else:
        # Probabilities aren't wanted, things are simple
        data = model.predict_instances(normalize(x.pixel_data),
                                       return_predict=False,
                                       n_tiles=tiles)
        y_data = data[0]

    y = Objects()
    y.segmented = y_data
    y.parent_image = x.parent_image

    workspace.object_set.add_objects(y, self.y_name.value)

    self.add_measurements(workspace)

    if self.show_window:
        workspace.display_data.x_data = x_data
        workspace.display_data.y_data = y_data
        workspace.display_data.dimensions = dimensions
from vollseg.OptimizeThreshold import OptimizeThreshold
# Set the HDF5_USE_FILE_LOCKING environment variable to "FALSE" for this process.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
from pathlib import Path

# In[2]:

# Dataset, model and output locations (absolute, machine-specific paths).
BaseDir = '/data/u934/service_imagerie/v_kapoor/CurieTrainingDatasets/MouseClaudia/AugmentedGreenCell3D/'
Model_Dir = '/data/u934/service_imagerie/v_kapoor/CurieDeepLearningModels/MouseClaudia/'
SaveDir = '/data/u934/service_imagerie/v_kapoor/CurieTrainingDatasets/MouseClaudia/'

StardistModelName = 'ScipyDeepGreenCells'
UNETModelName = 'UNETScipyDeepGreenCells'
NoiseModel = None

# config=None: both models are constructed from name and basedir only.
Starmodel = StarDist3D(config=None, name=StardistModelName, basedir=Model_Dir)
UnetModel = CARE(config=None, name=UNETModelName, basedir=Model_Dir)

# In[3]:

#Number of tiles to break the image into for applying the prediction to fit in the computer memory
n_tiles = (1, 2, 2)

#Use Probability map = True or distance map = False as the image to perform watershed on
UseProbability = False

# In[ ]:

# BaseDir already ends with '/', so these patterns contain a double slash.
Raw = sorted(glob.glob(BaseDir + '/Raw/' + '*.tif'))
RealMask = sorted(glob.glob(BaseDir + '/RealMask/' + '*.tif'))
# Read every raw image eagerly into a list.
X = list(map(imread, Raw))
# Script fragment: the parser and its earlier arguments (including the
# `input` and `model` values used below) are defined outside the visible
# part of the file.
parser.add_argument('-c', '--channel', type=str, help="channel")
parser.add_argument('-s', '--scale', type=str, help="scale")
args = parser.parse_args()

print("reading ...", args.input, args.channel + '/' + args.scale)

# Open the container with use_zarr_format=False and read the whole
# '<channel>/<scale>' dataset.
im = z5py.File(args.input, use_zarr_format=False)
img = im[args.channel + '/' + args.scale][:, :, :]

# One tile per started block of 128 voxels along each axis.
n_tiles = tuple(int(np.ceil(s / 128)) for s in img.shape)
print("estimated tiling:", n_tiles)

print("normalizing input...")
img_normed = normalize(img, 4, 99.8)

# NOTE(review): args.model is passed as both name and basedir; confirm
# this matches the expected model directory layout.
model = StarDist3D(None, name=args.model, basedir=args.model)

print("predicting...")
# the affinity based labels
label_starfinity, res_dict = model.predict_instances(img_normed,
                                                     n_tiles=n_tiles,
                                                     affinity=True,
                                                     affinity_thresh=0.1,
                                                     verbose=True)

# the normal stardist labels are implicitly calculated and
# can be accessed from the results dict
label_stardist = res_dict["markers"]

print("saving...")
def get_model(self) -> StarDistBase:
    """Return the pretrained StarDist3D model named by ``self.parameters["model"]``."""
    pretrained_name = self.parameters["model"]
    return StarDist3D.from_pretrained(pretrained_name)
def _test_model_multiclass(n_classes=1, classes="auto", n_channel=None, basedir=None, epochs=20, batch_size=1):
    """Train a (multi-class) StarDist3D on the real test image and compare predictors.

    Parameters
    ----------
    n_classes : number of object classes, or None for a model without
        class prediction.
    classes : "auto" derives classes from region areas when n_classes > 1;
        any other value is passed to ``model.train`` unchanged.
    n_channel : if not None, the image is repeated along a new last axis.
    basedir : directory for the model; None keeps the model out of a
        named directory (name is then None as well).
    epochs : forwarded to ``model.train``.
    batch_size : forwarded as ``train_batch_size``.

    Returns
    -------
    (model, tiled image, dense result, sparse result, big-prediction result)
    """
    from skimage.measure import regionprops

    img, mask = real_image3d()
    img = normalize(img, 1, 99.8)

    if n_channel is not None:
        img = np.repeat(img[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X, Y = [img, img, img], [mask, mask, mask]

    conf = Config3D(
        n_rays=32,
        grid=(2, 1, 2),
        n_channel_in=n_channel,
        n_classes=n_classes,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=10,
        train_batch_size=batch_size,
        # A third loss weight is given when classes are predicted.
        train_loss_weights=(1., .2) if n_classes is None else (1, .2, 1.),
        train_patch_size=(24, 32, 32),
    )

    # Define some classes according to the areas: regions are sorted by
    # area and split into n_classes consecutive groups, numbered from 1.
    if n_classes is not None and n_classes > 1 and classes == "auto":
        regs = regionprops(mask)
        areas = tuple(r.area for r in regs)
        inds = np.argsort(areas)
        ss = tuple(
            slice(n * len(regs) // n_classes, (n + 1) * len(regs) // n_classes)
            for n in range(n_classes))
        classes = {}
        for i, s in enumerate(ss):
            for j in inds[s]:
                classes[regs[j].label] = i + 1
        classes = (classes, ) * len(X)

    # Pass None through unchanged: str(None) would yield the literal
    # directory name "None".
    model = StarDist3D(conf,
                       name=None if basedir is None else "stardist",
                       basedir=None if basedir is None else str(basedir))

    # Validation assigns class 1 to every labelled object.
    val_classes = {k: 1 for k in set(mask[mask > 0])}

    model.train(X, Y,
                classes=classes,
                epochs=epochs,
                validation_data=(X[:1], Y[:1]) if n_classes is None else
                (X[:1], Y[:1], (val_classes, )))

    # Prediction on the original image; the result itself is not checked.
    model.predict_instances(img)
    # return model, X,Y, classes, labels, res

    # Enlarge the image by tiling before comparing the predictors.
    img = np.tile(img, (4, 2, 2) if img.ndim == 3 else (4, 2, 2, 1))

    kwargs = dict(prob_thresh=.5)
    labels1, res1 = model.predict_instances(img, **kwargs)
    labels2, res2 = model.predict_instances(img, sparse=True, **kwargs)
    labels3, res3 = model.predict_instances_big(
        img,
        axes="ZYX" if img.ndim == 3 else "ZYXC",
        block_size=96,
        min_overlap=8,
        context=8,
        **kwargs)

    # Dense and sparse prediction must agree on labels and on every
    # array-valued result entry.
    assert np.allclose(labels1, labels2)
    assert all(
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()).union(set(res2.keys()))
        if isinstance(res1[k], np.ndarray))

    return model, img, res1, res2, res3
def _parse(n_classes, classes):
    """Return ``classes`` as parsed by a StarDist3D built with ``n_classes`` (length 1)."""
    config = Config3D(n_classes=n_classes)
    parser_model = StarDist3D(config, None, None)
    return parser_model._parse_classes_arg(classes, length=1)