Example #1
def test_load_and_export_TF():
    model_path = path_model3d()
    model = StarDist3D(None,
                       name=model_path.name,
                       basedir=str(model_path.parent))
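    # export the model as a TensorFlow SavedModel, once without and once with
    # upsampling of the prediction grid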
    model.export_TF(single_output=True, upsample_grid=False)
    model.export_TF(single_output=True, upsample_grid=True)
Example #2
def _model3d():
    from utils import path_model3d
    from stardist.models import StarDist3D
    model_path = path_model3d()
    return StarDist3D(None,
                      name=model_path.name,
                      basedir=str(model_path.parent))
Example #3
def test_model(tmpdir, n_rays, grid, n_channel, backbone):
    img = circle_image(shape=(64, 80, 96))
    imgs = np.repeat(img[np.newaxis], 3, axis=0)

    if n_channel is not None:
        imgs = np.repeat(imgs[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X = imgs + .6 * np.random.uniform(0, 1, imgs.shape)
    Y = (imgs if imgs.ndim == 4 else imgs[..., 0]).astype(int)

    conf = Config3D(
        backbone=backbone,
        rays=n_rays,
        grid=grid,
        n_channel_in=n_channel,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=2,
        train_batch_size=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 64),
    )

    model = StarDist3D(conf, name='stardist', basedir=str(tmpdir))
    model.train(X, Y, validation_data=(X[:2], Y[:2]))
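    # predict once on the full volume and once with tiling
    # (the extra tile factor of 1 covers the channel axis)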
    ref = model.predict(X[0])
    res = model.predict(X[0],
                        n_tiles=((1, 2, 3) if X[0].ndim == 3 else
                                 (1, 2, 3, 1)))
Example #4
def read_model(model_name: str) -> StarDist3D:
    """Read StarDist model."""
    with HidePrint():
        return StarDist3D(None,
                          name=model_name,
                          basedir=str(
                              pl.Path(__file__).parents[1].joinpath('models')))
Example #5
def print_receptive_fields():
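    # compute and print the receptive field (field of view) for different
    # backbones, U-Net depths, and prediction grids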
    backbone = "unet"
    for n_depth in (1, 2, 3):
        for grid in ((1, 1, 1), (2, 2, 2)):
            conf = Config3D(backbone=backbone, grid=grid, unet_n_depth=n_depth)
            model = StarDist3D(conf, None, None)
            fov = model._compute_receptive_field()
            print(
                f"backbone: {backbone} \t n_depth: {n_depth} \t grid {grid} -> fov: {fov}"
            )
    backbone = "resnet"
    for grid in ((1, 1, 1), (2, 2, 2)):
        conf = Config3D(backbone=backbone, grid=grid)
        model = StarDist3D(conf, None, None)
        fov = model._compute_receptive_field()
        print(f"backbone: {backbone} \t grid {grid} -> fov: {fov}")
Example #6
def test_model(tmpdir, n_rays, grid, n_channel, backbone):
    img = circle_image(shape=(64, 80, 96))
    imgs = np.repeat(img[np.newaxis], 3, axis=0)

    if n_channel is not None:
        imgs = np.repeat(imgs[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X = imgs + .6 * np.random.uniform(0, 1, imgs.shape)
    Y = (imgs if imgs.ndim == 4 else imgs[..., 0]).astype(int)

    conf = Config3D(
        backbone=backbone,
        rays=n_rays,
        grid=grid,
        n_channel_in=n_channel,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=2,
        train_batch_size=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 64),
    )

    model = StarDist3D(conf, name='stardist', basedir=str(tmpdir))
    model.train(X, Y, validation_data=(X[:2], Y[:2]))
    ref = model.predict(X[0])
    res = model.predict(X[0],
                        n_tiles=((1, 2, 3) if X[0].ndim == 3 else
                                 (1, 2, 3, 1)))
    # assert all(np.allclose(u,v) for u,v in zip(ref,res))

    # ask to train only with foreground patches when there are none
    # include a constant label image that must trigger a warning
    conf.train_foreground_only = 1
    conf.train_steps_per_epoch = 1
    _X = X[:2]
    _Y = [np.zeros_like(Y[0]), np.ones_like(Y[1])]
    with pytest.warns(UserWarning):
        StarDist3D(conf, name='stardist',
                   basedir=None).train(_X,
                                       _Y,
                                       validation_data=(X[-1:], Y[-1:]))
Example #7
def test_load_and_export_TF():
    model_path = path_model3d()
    model = StarDist3D(None,
                       name=model_path.name,
                       basedir=str(model_path.parent))
    assert any(g > 1 for g in model.config.grid)
    # model.export_TF(single_output=False, upsample_grid=False)
    # model.export_TF(single_output=False, upsample_grid=True)
    model.export_TF(single_output=True, upsample_grid=False)
    model.export_TF(single_output=True, upsample_grid=True)
Example #8
def test_load_and_predict_with_overlap():
    model_path = path_model3d()
    model = StarDist3D(None, name=model_path.name, basedir=str(model_path.parent))
    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)
    prob, dist = model.predict(x, n_tiles=(1, 2, 2))
    assert prob.shape == dist.shape[:3]
    assert model.config.n_rays == dist.shape[-1]
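    # pixels covered by more than one object are marked with overlap_label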
    labels, _ = model.predict_instances(x, nms_thresh=.5,
                                        overlap_label=-3)
    assert np.min(labels) == -3
    return model, labels
Example #9
def test_load_and_predict():
    model_path = path_model3d()
    model = StarDist3D(None, name=model_path.name, basedir=str(model_path.parent))
    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)
    prob, dist = model.predict(x, n_tiles=(1, 2, 2))
    assert prob.shape == dist.shape[:3]
    assert model.config.n_rays == dist.shape[-1]
    labels, _ = model.predict_instances(x)
    assert labels.shape == img.shape[:3]
    stats = matching(mask, labels, thresh=0.5)
    assert (stats.fp, stats.tp, stats.fn) == (0, 30, 21)
    return model, labels
Example #10
def test_predict_dense_sparse():
    model_path = path_model3d()
    model = StarDist3D(None,
                       name=model_path.name,
                       basedir=str(model_path.parent))
    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)
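    # sparse=True evaluates distances only at candidate locations;
    # the result should match the dense prediction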
    labels1, res1 = model.predict_instances(x, n_tiles=(1, 2, 2), sparse=False)
    labels2, res2 = model.predict_instances(x, n_tiles=(1, 2, 2), sparse=True)
    assert np.allclose(labels1, labels2)
    assert all(
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()).union(set(res2.keys())))
    return labels1, labels2
Example #11
def test_foreground_warning():
    # ask to train only with foreground patches when there are none
    # include a constant label image that must trigger a warning
    conf = Config3D(
        n_rays=32,
        train_patch_size=(16, 32, 16),
        train_foreground_only=1,
        train_steps_per_epoch=1,
        train_epochs=1,
        train_batch_size=2,
    )
    X, Y = np.ones((2, 32, 48, 16), np.float32), np.ones((2, 32, 48, 16),
                                                         np.uint16)

    with pytest.warns(UserWarning):
        StarDist3D(conf, None, None).train(X,
                                           Y,
                                           validation_data=(X[-1:], Y[-1:]))
Example #12
def test_optimize_thresholds():
    model_path = path_model3d()
    model = StarDist3D(None, name=model_path.name, basedir=str(model_path.parent))
    img, mask = real_image3d()
    x = normalize(img, 1, 99.8)
    def _opt(model):
        return model.optimize_thresholds([x], [mask],
                                         nms_threshs=[.3, .5],
                                         iou_threshs=[.3, .5],
                                         optimize_kwargs=dict(tol=1e-1),
                                         save_to_json=False)

    t1 = _opt(model)
    # enforce implicit tiling
    model.config.train_batch_size = 1
    model.config.train_patch_size = tuple(s - 1 for s in x.shape)
    t2 = _opt(model)
    assert all(np.allclose(t1[k], t2[k]) for k in t1.keys())
    return model
Example #13
def test_model(n_rays, grid):
    img = circle_image(shape=(64, 80, 96))
    imgs = np.repeat(img[np.newaxis], 10, axis=0)

    X = imgs + .6 * np.random.uniform(0, 1, imgs.shape)
    Y = imgs.astype(np.uint16)

    conf = Config3D(
        rays=n_rays,
        grid=grid,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 64),
        n_channel_in=1)

    with tempfile.TemporaryDirectory() as tmp:
        model = StarDist3D(conf, name='stardist', basedir=tmp)
        model.train(X, Y, validation_data=(X[:2],Y[:2]))
Example #14
def train_model(x_train,
                y_train,
                x_val,
                y_val,
                save_path,
                patch_size,
                anisotropy,
                n_rays=96):

    rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)
    # make the model config
    # copied from the stardist training notebook, this is a very weird line ...
    use_gpu = False and gputools_available()
    # predict on subsampled image for increased efficiency
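    # subsample only along (near-)isotropic axes (grid=2); keep strongly
    # anisotropic axes at full resolution (grid=1)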
    grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)
    config = Config3D(rays=rays,
                      grid=grid,
                      use_gpu=use_gpu,
                      n_channel_in=1,
                      train_patch_size=patch_size,
                      anisotropy=anisotropy)

    if use_gpu:
        print("Using a GPU for training")
        # limit gpu memory
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(0.8)
    else:
        print("GPU not found, using the CPU for training")

    save_root, save_name = os.path.split(save_path)
    os.makedirs(save_root, exist_ok=True)
    model = StarDist3D(config, name=save_name, basedir=save_root)

    model.train(x_train,
                y_train,
                validation_data=(x_val, y_val),
                augmenter=augmenter)
    optimal_parameters = model.optimize_thresholds(x_val, y_val)
    return optimal_parameters
Example #15
def test_model(tmpdir, n_rays, grid, n_channel, backbone, workers,
               use_sequence):
    img = circle_image(shape=(64, 80, 96))
    imgs = np.repeat(img[np.newaxis], 3, axis=0)

    if n_channel is not None:
        imgs = np.repeat(imgs[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X = imgs + .6 * np.random.uniform(0, 1, imgs.shape)
    Y = (imgs if imgs.ndim == 4 else imgs[..., 0]).astype(int)

    if use_sequence:
        X, Y = NumpySequence(X), NumpySequence(Y)

    conf = Config3D(
        backbone=backbone,
        rays=n_rays,
        grid=grid,
        n_channel_in=n_channel,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=1,
        train_batch_size=2,
        train_loss_weights=(4, 1),
        train_patch_size=(48, 64, 32),
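        # don't cache training patches in memory when data comes from a Sequence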
        train_sample_cache=not use_sequence,
    )

    model = StarDist3D(conf, name='stardist', basedir=str(tmpdir))
    model.train(X, Y, validation_data=(X[:2], Y[:2]), workers=workers)
    ref = model.predict(X[0])
    res = model.predict(X[0],
                        n_tiles=((1, 2, 3) if X[0].ndim == 3 else
                                 (1, 2, 3, 1)))

    # deactivate as order of labels might not be the same
    # assert all(np.allclose(u,v) for u,v in zip(ref,res))
    return model
Example #16
def run_prediction(image_files, model_path, root, prediction_folder):

    # load the model
    model_root, model_name = os.path.split(model_path.rstrip('/'))
    model = StarDist3D(None, name=model_name, basedir=model_root)

    res_folder = os.path.join(root, prediction_folder)
    os.makedirs(res_folder, exist_ok=True)

    # normalization parameters: lower and upper percentile used for image normalization
    # maybe these should be exposed
    lower_percentile = 1
    upper_percentile = 99.8
    ax_norm = (0, 1, 2)

    for im_file in tqdm(image_files, desc="run stardist prediction"):
        im = imageio.volread(im_file)
        im = normalize(im, lower_percentile, upper_percentile, axis=ax_norm)
        pred, _ = model.predict_instances(im)

        im_name = os.path.split(im_file)[1]
        save_path = os.path.join(res_folder, im_name)
        imageio.imsave(save_path, pred)
Example #17
    def Train(self):

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:

            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')

            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)
            vars(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name + '/' +
                                  'weights_now.h5') and not os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/' +
                                      'weights_now.h5'):
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if not self.CroppedLoad:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                          ], [self.Y[i] for i in ind_val]
                self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                          ], [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)
            vars(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_now.h5') and not os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_now.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_now.h5')
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_last.h5') and not os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_last.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_best.h5') and not os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_best.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_best.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Example #18
def run_StarDist(
        inputImagePath, z_count, t_count, model_selection,
        probThreshold, nmsThreshold, normalizationLow, normalizationHigh,
        output_type, resultPath):

    print('------------------------------------------')
    print('       StarDist Virtual Environment')
    print('------------------------------------------')
    print(f'   inputImagePath = {inputImagePath}')
    print(f'          z_count = {z_count}')
    print(f'          t_count = {t_count}')
    print(f'  model_selection = {model_selection}')
    print(f'    probThreshold = {probThreshold}')
    print(f'     nmsThreshold = {nmsThreshold}')
    print(f' normalizationLow = {normalizationLow}')
    print(f'normalizationHigh = {normalizationHigh}')
    print(f'       outputType = {output_type}')
    print(f'       resultPath = {resultPath}')

    # Limit GPU memory usage
    limit_gpu_memory(fraction=None, allow_growth=True)

    # Get the path of the folder that contains this python script
    script_folder = pathlib.Path(__file__).resolve().parent
    logger.info(f'Script Folder = {script_folder}')

    # Get the model selections
    model_dict = {0: '2D_demo', 1: '2D_fluor_nuc',
                  2: '2D_dsb2018', 3: '3D_demo'}
    if model_selection not in model_dict:
        logger.warn('Selection is not available, use 2D_demo instead')
        model_selection = 0
    model_name = model_dict[model_selection]

    # Load StarDist model assuming the 3D_demo and 2D_demo folders
    # are both in `script_folder`
    if z_count > 1:
        # input image is 3D or 3D+T
        logger.warn('Input is a 3D/3D+T image, use 3D_demo model')
        # Use 3D model for 3D image
        model = StarDist3D(None, name='3D_demo', basedir=script_folder)
        # Set 3D block size
        tile_shape = (50, 256, 256)
        # Check if input is a time-lapse image
        if t_count > 1:
            axes = 'YXZT'
        else:
            axes = 'YXZ'
    elif z_count == 1:
        # Use 2D model for 2D image
        model = StarDist2D(None, name=model_name, basedir=script_folder)
        # Set 2D tile size
        tile_shape = (512, 512)
        # Check if input is a time-lapse image
        if t_count > 1:
            axes = 'TYX'
        else:
            axes = 'YX'
    else:
        raise ValueError('Z count must be positive')

    # Load input image
    image = tifffile.imread(inputImagePath)
    dtype = image.dtype

    # Current limitation: input and output should have the same depth
    if dtype == np.uint8:
        logger.warn('Label image will be saved in 8bit')

    # Not a time-lapse
    if t_count == 1:
        image = image[np.newaxis]

    # Create output labeled image
    labels = np.empty_like(image, dtype=dtype)
    n_tiles = [i // t + 1 for t, i in zip(tile_shape, image[0].shape)]

    # Get thresholds
    prob_thresh = np.clip(probThreshold, 0.0, 1.0)
    nms_thresh = np.clip(nmsThreshold, 0.0, 1.0)

    # Use default thresholds optimized for the StarDist model when both
    # thresholds are set as 0
    if prob_thresh == 0.0 and nms_thresh == 0.0:
        logger.warn(
            'Use default thresholds of the StarDist model when both '
            'thresholds are set as 0.')
        prob_thresh = nms_thresh = None

    logger.info(f'probThreshold = {prob_thresh}, nmsThreshold = {nms_thresh}')

    # Get Normalization Percentile
    p_min = np.clip(normalizationLow, 0.0, 100.0)
    p_max = np.clip(normalizationHigh, 0.0, 100.0)

    # Use default normalization for the StarDist model when p_min >= p_max
    if p_min >= p_max:
        logger.warn(
            'Use default normalization of the StarDist model '
            'when p_min >= p_max.')
        p_min, p_max = 2, 99.9

    logger.info(f'normalizationLow = {p_min}, normalizationHigh = {p_max}')

    # Apply StarDist model
    for t in range(t_count):
        labels[t] = model.predict_instances(
            normalize(image[t], p_min=p_min, p_max=p_max),
            prob_thresh=prob_thresh,
            nms_thresh=nms_thresh,
            n_tiles=n_tiles,
            show_tile_progress=False)[0].astype(dtype)

        # Convert labeled mask to binary mask
        if (output_type == 1):
            # Add two pixel gap between neighboring masks
            if (z_count > 1):
                addOnePixelGap_3D(labels[t])
            else:
                addOnePixelGap_2D(labels[t])

            # Convert to binary mask
            val = 255
            if (labels[t].dtype.type == np.uint16):
                val = 65535
            labels[t] = (labels[t] > 0) * val

    # Not a time-lapse
    if t_count == 1:
        labels = labels[0]

    # Save the labeled image
    tifffile.imwrite(resultPath,
                     labels,
                     photometric='minisblack',
                     metadata={'axes': axes})
Example #19
    def __len__(self):
        return self.n


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-n", type=int, default=10)
    parser.add_argument("-s", "--size", type=int, default=256)
    parser.add_argument("--steps", type=int, default=20)
    parser.add_argument("--nocache", action='store_true')

    args = parser.parse_args()

    X, Y = LargeSequence(20, size=args.size), LargeSequence(20, size=args.size)

    conf = Config3D(n_rays=32,
                    backbone="unet",
                    unet_n_depth=1,
                    train_epochs=args.n,
                    train_steps_per_epoch=args.steps,
                    train_batch_size=1,
                    train_patch_size=(min(96, args.size), ) * 3,
                    train_sample_cache=not args.nocache)

    model = StarDist3D(conf, None, None)
    model.prepare_for_training()
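    # print memory usage after every epoch to watch for leaks while streaming
    # training data from the Sequence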
    model.callbacks.append(
        LambdaCallback(on_epoch_end=lambda a, b: print_memory()))
    model.train(X, Y, validation_data=(X[0][np.newaxis], Y[0][np.newaxis]))
Example #20
    def run(self, workspace):
        images = workspace.image_set
        x = images.get_image(self.x_name.value)
        dimensions = x.dimensions
        x_data = x.pixel_data

        # Validate some settings
        if self.model.value in (GREY_1, GREY_2) and x.multichannel:
            raise ValueError(
                "Color images are not supported by this model. Please provide greyscale images."
            )
        elif self.model.value == COLOR_1 and not x.multichannel:
            raise ValueError(
                "Greyscale images are not supported by this model. Please provide a color overlay."
            )

        if self.model.value != MODEL_CUSTOM:
            if x.volumetric:
                raise NotImplementedError(
                    "StarDist's inbuilt models do not currently support 3D images"
                )
            model = StarDist2D.from_pretrained(self.model.value)
        else:
            model_directory, model_name = os.path.split(
                self.model_directory.get_absolute_path())
            if x.volumetric:
                from stardist.models import StarDist3D
                model = StarDist3D(config=None,
                                   basedir=model_directory,
                                   name=model_name)
            else:
                model = StarDist2D(config=None,
                                   basedir=model_directory,
                                   name=model_name)

        tiles = None
        if self.tile_image.value:
            tiles = []
            if x.volumetric:
                tiles += [1]
            tiles += [self.n_tiles_x.value, self.n_tiles_y.value]
            # Handle colour channels
            tiles += [1] * max(0, x.pixel_data.ndim - len(tiles))
            print(x.pixel_data.shape, x.pixel_data.ndim, tiles)

        if not self.save_probabilities.value:
            # Probabilities aren't wanted, things are simple
            data = model.predict_instances(normalize(x.pixel_data),
                                           return_predict=False,
                                           n_tiles=tiles)
            y_data = data[0]
        else:
            data, probs = model.predict_instances(normalize(x.pixel_data),
                                                  return_predict=True,
                                                  sparse=False,
                                                  n_tiles=tiles)
            y_data = data[0]

            # Scores aren't at the same resolution as the input image.
            # We need to slightly resize to match the original image.
            size_corrected = resize(probs[0], y_data.shape)
            prob_image = Image(
                size_corrected,
                parent_image=x.parent_image,
                convert=False,
                dimensions=len(size_corrected.shape),
            )

            workspace.image_set.add(self.probabilities_name.value, prob_image)

            if self.show_window:
                workspace.display_data.probabilities = size_corrected

        y = Objects()
        y.segmented = y_data
        y.parent_image = x.parent_image
        objects = workspace.object_set
        objects.add_objects(y, self.y_name.value)

        self.add_measurements(workspace)

        if self.show_window:
            workspace.display_data.x_data = x_data
            workspace.display_data.y_data = y_data
            workspace.display_data.dimensions = dimensions
Example #21
from vollseg.OptimizeThreshold import OptimizeThreshold
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
from pathlib import Path

# In[2]:

BaseDir = '/data/u934/service_imagerie/v_kapoor/CurieTrainingDatasets/MouseClaudia/AugmentedGreenCell3D/'

Model_Dir = '/data/u934/service_imagerie/v_kapoor/CurieDeepLearningModels/MouseClaudia/'
SaveDir = '/data/u934/service_imagerie/v_kapoor/CurieTrainingDatasets/MouseClaudia/'

StardistModelName = 'ScipyDeepGreenCells'
UNETModelName = 'UNETScipyDeepGreenCells'

NoiseModel = None
Starmodel = StarDist3D(config=None, name=StardistModelName, basedir=Model_Dir)
UnetModel = CARE(config=None, name=UNETModelName, basedir=Model_Dir)

# In[3]:

#Number of tiles to break the image into for applying the prediction to fit in the computer memory
n_tiles = (1, 2, 2)

#Use Probability map = True or distance map = False as the image to perform watershed on
UseProbability = False

# In[ ]:

Raw = sorted(glob.glob(BaseDir + '/Raw/' + '*.tif'))
RealMask = sorted(glob.glob(BaseDir + '/RealMask/' + '*.tif'))
X = list(map(imread, Raw))
Example #22
    parser.add_argument('-c', '--channel', type=str, help="channel")
    parser.add_argument('-s', '--scale', type=str, help="scale")

    args = parser.parse_args()

    print("reading ...", args.input, args.channel + '/' + args.scale)
    im = z5py.File(args.input, use_zarr_format=False)
    img = im[args.channel + '/' + args.scale][:, :, :]

    n_tiles = tuple(int(np.ceil(s / 128)) for s in img.shape)
    print("estimated tiling:", n_tiles)

    print("normalizing input...")
    img_normed = normalize(img, 4, 99.8)

    model = StarDist3D(None, name=args.model, basedir=args.model)

    print("predicting...")
    # the affinity based labels
    label_starfinity, res_dict = model.predict_instances(img_normed,
                                                         n_tiles=n_tiles,
                                                         affinity=True,
                                                         affinity_thresh=0.1,
                                                         verbose=True)

    # the normal stardist labels are implicitly calculated and
    # can be accessed from the results dict
    label_stardist = res_dict["markers"]

    print("saving...")
Example #23
    def get_model(self) -> StarDistBase:
        return StarDist3D.from_pretrained(self.parameters["model"])
Example #24
def _test_model_multiclass(n_classes=1,
                           classes="auto",
                           n_channel=None,
                           basedir=None,
                           epochs=20,
                           batch_size=1):
    from skimage.measure import regionprops

    img, mask = real_image3d()
    img = normalize(img, 1, 99.8)

    if n_channel is not None:
        img = np.repeat(img[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X, Y = [img, img, img], [mask, mask, mask]

    conf = Config3D(
        n_rays=32,
        grid=(2, 1, 2),
        n_channel_in=n_channel,
        n_classes=n_classes,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=10,
        train_batch_size=batch_size,
        train_loss_weights=(1., .2) if n_classes is None else (1, .2, 1.),
        train_patch_size=(24, 32, 32),
    )

    # define some classes according to the areas
    if n_classes is not None and n_classes > 1 and classes == "auto":
        regs = regionprops(mask)
        areas = tuple(r.area for r in regs)
        inds = np.argsort(areas)
        ss = tuple(
            slice(n * len(regs) // n_classes, (n + 1) * len(regs) // n_classes)
            for n in range(n_classes))
        classes = {}
        for i, s in enumerate(ss):
            for j in inds[s]:
                classes[regs[j].label] = i + 1
        classes = (classes, ) * len(X)

    model = StarDist3D(conf,
                       name=None if basedir is None else "stardist",
                       basedir=str(basedir))

    val_classes = {k: 1 for k in set(mask[mask > 0])}

    s = model.train(X,
                    Y,
                    classes=classes,
                    epochs=epochs,
                    validation_data=(X[:1], Y[:1]) if n_classes is None else
                    (X[:1], Y[:1], (val_classes, )))

    labels, res = model.predict_instances(img)
    # return  model, X,Y, classes, labels, res

    img = np.tile(img, (4, 2, 2) if img.ndim == 3 else (4, 2, 2, 1))

    kwargs = dict(prob_thresh=.5)
    labels1, res1 = model.predict_instances(img, **kwargs)
    labels2, res2 = model.predict_instances(img, sparse=True, **kwargs)
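    # block-wise prediction for large images: blocks of size block_size with
    # min_overlap and context are processed separately and stitched back together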
    labels3, res3 = model.predict_instances_big(
        img,
        axes="ZYX" if img.ndim == 3 else "ZYXC",
        block_size=96,
        min_overlap=8,
        context=8,
        **kwargs)

    assert np.allclose(labels1, labels2)
    assert all([
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()).union(set(res2.keys()))
        if isinstance(res1[k], np.ndarray)
    ])

    return model, img, res1, res2, res3
Example #25
    def _parse(n_classes, classes):
        model = StarDist3D(Config3D(n_classes=n_classes), None, None)
        classes = model._parse_classes_arg(classes, length=1)
        return classes