Example #1
def test_model_train():
    rng = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1, 2],
        n_channel_out=[1, 2],
        probabilistic=[False, True],
        # unet_residual         = [False,True],
        unet_n_depth=[1],
        unet_kern_size=[3],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        # unet_input_shape      = [(None, None, 1)],
        train_loss=['mae', 'laplace'],
        train_epochs=[2],
        train_steps_per_epoch=[2],
        # train_learning_rate   = [0.0004],
        train_batch_size=[2],
        # train_tensorboard     = [True],
        # train_checkpoint      = ['weights_best.h5'],
        # train_reduce_lr       = [{'factor': 0.5, 'patience': 10}],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        for config in configs:
            K.clear_session()
            if config.is_valid():
                X = rng.uniform(size=(4, ) + (32, ) * config.n_dim +
                                (config.n_channel_in, ))
                Y = rng.uniform(size=(4, ) + (32, ) * config.n_dim +
                                (config.n_channel_out, ))
                model = CARE(config, basedir=tmpdir)
                model.train(X, Y, (X, Y))
Example #2
def train_care_generated_data(model: CARE,
                              epochs: int,
                              X: np.ndarray,
                              Y: np.ndarray,
                              X_val: np.ndarray,
                              Y_val: np.ndarray,
                              steps_per_epoch: int = 400) -> None:
    """
    Train a CARE model on a dataset

    :param model: The CARE model to train
    :param epochs: The number of epochs to train for
    :param X: The training data input
    :param Y: The training data expected output
    :param X_val: The validation data input
    :param Y_val: The validation data expected output
    :param steps_per_epoch: The number of training steps per epoch
    """
    rearr = lambda arr: np.moveaxis(arr, [0, 1, 2, 3, 4], [0, 4, 1, 2, 3])
    rearry = lambda arr: np.moveaxis(arr, [0, 1, 2, 3, 4], [0, 4, 1, 2, 3]
                                     )[:, :, :, :, :1]
    X = rearr(X)
    X_val = rearr(X_val)
    Y = rearry(Y)
    Y_val = rearry(Y_val)

    model.train(X,
                Y,
                validation_data=(X_val, Y_val),
                epochs=epochs,
                steps_per_epoch=steps_per_epoch)
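
A minimal usage sketch, assuming 5D arrays laid out as (sample, channel, Z, Y, X), the layout this helper converts to channels-last, and an already-configured CARE model with matching axes and channel counts (shapes, names, and epoch counts below are illustrative):

import numpy as np

rng = np.random.RandomState(0)
X = rng.uniform(size=(8, 1, 16, 32, 32))   # training input, channel-first
Y = rng.uniform(size=(8, 1, 16, 32, 32))   # training target; only its first channel is used
X_val, Y_val = X[:2], Y[:2]
# 'model' is assumed to be a CARE instance configured for axes 'ZYX' (see the other examples)
train_care_generated_data(model, epochs=2, X=X, Y=Y,
                          X_val=X_val, Y_val=Y_val, steps_per_epoch=4)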
Example #3
def test_model_train(tmpdir,config):
    rng = np.random.RandomState(42)
    K.clear_session()
    X = rng.uniform(size=(4,)+(32,)*config.n_dim+(config.n_channel_in,))
    Y = rng.uniform(size=(4,)+(32,)*config.n_dim+(config.n_channel_out,))
    model = CARE(config,basedir=str(tmpdir))
    model.train(X,Y,(X,Y))
Example #4
def run_z(y, x, models):
    # Prep data
    # model prediction and normalizations output float64
    x = np.array(x, dtype='float64')
    y = np.array(y, dtype='float64')
    axes = 'ZYX'

    # Prepare the dict for metrics
    d = {
        'output': [],
        'columns': ['output', 'rmse', 'ssim'],
        'id_vars': ['output'],
        'var_name': 'metric'
    }

    # Define comparisons
    def get_output(name, y, x):
        return [name, np.sqrt(mse(x, y)), ssim(x, y)]

    # Normalize GT dynamic range to enable comparisons w/ numbers
    yn = util.percentile_norm(y, axes)

    # Get the comparison for normalized input im
    xn = util.percentile_norm(x, axes)
    d['output'].append(get_output('input', yn, xn))

    predictions = {'models': ['input', 'N(GT)'], 'ims': [xn, yn]}
    for m in models:
        model = CARE(config=None, name=m, basedir='models')
        restored = model.predict_probabilistic(x, axes, n_tiles=(1, 4, 4))
        pred = restored.mean()
        y_pred_n = util.percentile_norm(pred, axes)

        d['output'].append(get_output(m, yn, y_pred_n))
        predictions['models'].append(m)
        predictions['ims'].append(y_pred_n)

    # Plot a random stack
    zix = util.get_randint_ixs(1, len(y))
    ims = [[im[zix] for im in predictions['ims']]]
    plt.figure(figsize=(16, 10))
    plot_some(np.stack(ims), title_list=[predictions['models']])
    plt.show()

    # Construct a DataFrame for the barplot
    df = pd.DataFrame(
        d['output'],
        columns=d['columns'],
    )
    df = pd.melt(df, id_vars=d['id_vars'], var_name=d['var_name'])

    g = sns.catplot(x='metric',
                    y='value',
                    hue='output',
                    kind='bar',
                    sharey=False,
                    data=df)
    g.ax.set_ylim(0, 1)
    plt.show()
Example #5
    def train(self, channels=None, **config_args):
        #limit_gpu_memory(fraction=1)
        if channels is None:
            channels = self.train_channels

        for ch in channels:
            print("-- Training channel {}...".format(ch))
            (X, Y), (X_val, Y_val), axes = load_training_data(
                self.get_training_patch_path() /
                'CH_{}_training_patches.npz'.format(ch),
                validation_split=0.1,
                verbose=False)

            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            train_epochs=self.train_epochs,
                            train_steps_per_epoch=self.train_steps_per_epoch,
                            train_batch_size=self.train_batch_size,
                            **config_args)
            # Training
            model = CARE(config,
                         'CH_{}_model'.format(ch),
                         basedir=pathlib.Path(self.out_dir) / 'models')

            # Show learning curve and example validation results
            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\nTraining data too big for the GPU. Are other GPU jobs running? Perhaps reduce the batch size or patch size."
                )
                return

            #print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

            plt.figure(figsize=(12, 7))
            _P = model.keras_model.predict(X_val[:5])

            plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
            plt.suptitle('5 example validation patches\n'
                         'top row: input (source),  '
                         'middle row: target (ground truth),  '
                         'bottom row: predicted from source')

            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Example #6
def gen_care_single_model(name: str) -> CARE:
    """
    Generate a single channel CARE model or retrieve the one that already exists under the name specified

    :param name: The name of the model
    :return: The CARE model
    """
    try:
        model = CARE(None, name)
    except FileNotFoundError:
        config = Config('xyzc', n_channel_in=1, n_channel_out=1)
        model = CARE(config, name)

    return model
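
A hedged usage sketch: the first call with a given (hypothetical) name creates and saves a fresh model, while a second call with the same name reloads it via CARE(None, name) instead of rebuilding it:

model = gen_care_single_model('denoise_single')        # created from a fresh Config on first use
model_again = gen_care_single_model('denoise_single')  # reloaded, since the model folder now exists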
Example #7
def dev(args):
    import json

    # Load and parse training data
    (X,
     Y), (X_val,
          Y_val), axes = load_training_data(args.train_data,
                                            validation_split=args.valid_split,
                                            axes=args.axes,
                                            verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        config_args = json.load(open(args.config))
        config = Config(**config_args)
    else:
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
Example #8
def load_data_and_model(npz_data, model_name, valid_split=0.2, axes='SCZXY'):

    model = CARE(config=None, name=model_name, basedir='models')
    X_val, Y_val = load_training_data(npz_data,
                                      validation_split=valid_split,
                                      axes=axes,
                                      verbose=True)[1]
    return X_val, Y_val, model
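
A short usage sketch with hypothetical file and model names; the returned validation split can be pushed straight through the loaded network for a quick sanity check:

X_val, Y_val, model = load_data_and_model('data/my_training_data.npz', 'my_model')
pred = model.keras_model.predict(X_val[:5])   # raw network output for a few validation patches
print(pred.shape, Y_val[:5].shape)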
Example #9
def test_model_predict_tiled(tmpdir,config):
    """
    Test that tiled prediction yields the same
    or a similar result as predicting
    the whole image at once.
    """
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config,basedir=str(tmpdir))

    def _predict(imdims,axes,n_tiles):
        img = rng.uniform(size=imdims)
        # print(img.shape, axes)
        mean,       scale       = model._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles=None)
        mean_tiled, scale_tiled = model._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles=n_tiles)
        assert mean.shape == mean_tiled.shape
        if config.probabilistic:
            assert scale.shape == scale_tiled.shape
        error_max = np.max(np.abs(mean-mean_tiled))
        # print('n, k, err = {0}, {1}x{1}, {2}'.format(model.config.unet_n_depth, model.config.unet_kern_size, error_max))
        assert error_max < 1e-3
        return mean, mean_tiled

    imdims = list(rng.randint(50,70,size=config.n_dim))
    if config.n_dim == 3:
        imdims[0] = 16 # make one dim small, otherwise test takes too long
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in imdims]

    imdims.insert(0,config.n_channel_in)
    axes = 'C'+config.axes.replace('C','')

    for n_tiles in (
        -1, 1.2,
        [1]+[1.2]*config.n_dim,
        [1]*config.n_dim, # missing value for channel axis
        [2]+[1]*config.n_dim, # >1 tiles for channel axis
    ):
        with pytest.raises(ValueError):
            _predict(imdims,axes,n_tiles)

    for n_tiles in [list(rng.randint(1,5,size=config.n_dim)) for _ in range(3)]:
        # print(imdims,axes,[1]+n_tiles)
        if config.n_channel_in == 1:
            _predict(imdims[1:],axes[1:],n_tiles)
        _predict(imdims,axes,[1]+n_tiles)

    # legacy api: tile only largest dimension
    n_blocks = np.max(imdims) // div_n
    for n_tiles in (2,5,n_blocks+1):
        with pytest.warns(UserWarning):
            if config.n_channel_in == 1:
                _predict(imdims[1:],axes[1:],n_tiles)
            _predict(imdims,axes,n_tiles)
Example #10
def test_model_predict():
    rng = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1, 2],
        n_channel_out=[1, 2],
        probabilistic=[False, True],
        # unet_residual         = [False,True],
        unet_n_depth=[2],
        unet_kern_size=[3],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        # unet_input_shape      = [(None, None, 1)],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        normalizer, resizer = NoNormalizer(), NoResizer()

        for config in filter(lambda c: c.is_valid(), configs):
            K.clear_session()
            model = CARE(config, basedir=tmpdir)
            axes = config.axes

            def _predict(imdims, axes):
                img = rng.uniform(size=imdims)
                # print(img.shape, axes, config.n_channel_out)
                mean, scale = model._predict_mean_and_scale(
                    img, axes, normalizer, resizer)
                if config.probabilistic:
                    assert mean.shape == scale.shape
                else:
                    assert scale is None

                if 'C' not in axes:
                    if config.n_channel_out == 1:
                        assert mean.shape == img.shape
                    else:
                        assert mean.shape == img.shape + (
                            config.n_channel_out, )
                else:
                    channel = axes_dict(axes)['C']
                    imdims[channel] = config.n_channel_out
                    assert mean.shape == tuple(imdims)

            imdims = list(rng.randint(20, 40, size=config.n_dim))
            div_n = 2**config.unet_n_depth
            imdims = [(d // div_n) * div_n for d in imdims]

            if config.n_channel_in == 1:
                _predict(imdims, axes=axes.replace('C', ''))

            channel = rng.randint(0, config.n_dim)
            imdims.insert(channel, config.n_channel_in)
            _axes = axes.replace('C', '')
            _axes = _axes[:channel] + 'C' + _axes[channel:]
            _predict(imdims, axes=_axes)
Example #11
def gen_care_dual_model(name: str, batch_size: int = 16, **kwargs):
    """
    Generate a dual channel CARE model or retrieve the one that already exists under the name specified

    :param name: The name of the model
    :param batch_size: The training batch size to use (only used if the model doesn't exist yet)
    :param kwargs: Parameters to pass to the model constructor (only used if the model doesn't exist yet)
    :return: The CARE model
    """
    try:
        model = CARE(None, name)
    except FileNotFoundError:
        config = Config('xyzc',
                        n_channel_in=2,
                        n_channel_out=1,
                        train_batch_size=batch_size,
                        **kwargs)
        model = CARE(config, name)

    return model
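
As with the single-channel helper above, a hedged usage sketch; the keyword overrides are only consulted when the model does not exist yet (the name and values below are illustrative):

model = gen_care_dual_model('dual_to_single', batch_size=8, train_epochs=50)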
Example #12
def _run_multi(y, x, models, d, ix, args):
    """ Try an alternative normalization to see if that's causing the
    discrepancy with the paper results
    """
    # Prep data
    axes = 'ZYX'

    def get_output(name, y, x):
        # Define comparisons
        return [name, np.sqrt(mse(x, y)), ssim(x, y)]

    # Normalize GT dynamic range to enable comparisons w/ numbers
    yn = normalize(y, pmin=0.1, pmax=99.9)

    # Get the comparison for normalized input im
    if args.rescale_x:
        xn = normalize_minmse(x, yn)
    else:
        xn = normalize(x)
    d['output'].append(get_output('input', yn, xn))

    for m in models:
        model = CARE(config=None, name=m, basedir='models')
        pred = model.predict(xn, axes, n_tiles=(1, 4, 4), normalizer=None)
        pred = normalize_minmse(pred, yn)
        d['output'].append(get_output(m, yn, pred))

        # Save the first test volume for each model
        if args.save_predictions:
            if ix == 0:
                save_prediction(pred, m, d['condition'])

    # Save the first test volume input and GT
    if args.save_predictions:
        if ix == 0:
            save_prediction(xn, 'input', d['condition'])
            save_prediction(yn, None, d['condition'])

    return d
Example #13
def _run_multi(y, x, models, d):
    # Prep data
    axes = 'ZYX'

    # Define comparisons
    def get_output(name, y, x):
        return [name, np.sqrt(mse(x, y)), ssim(x, y)]

    # Normalize GT dynamic range to enable comparisons w/ numbers
    yn = normalize(y, pmin=0.1, pmax=99.9)

    # Get the comparison for normalized input im
    xn = normalize(x)
    d['output'].append(get_output('input', yn, xn))

    for m in models:
        model = CARE(config=None, name=m, basedir='models')
        # None normalizer, already normalizing x, and this way can report the
        # exact pmin/max params used
        pred = model.predict(xn, n_tiles=(1, 4, 4), axes=axes, normalizer=None)
        pred = normalize_minmse(pred, yn)
        d['output'].append(get_output(m, yn, pred))

    return d
Example #14
def tst_care(care_model_path: Path,
             lfd_path: Path,
             overwrite: Optional[bool] = False):
    assert lfd_path.name == "pred"
    assert lfd_path.parent.name == "lfd"
    from csbdeep.models import CARE

    axes = "ZYX"
    model = CARE(config=None,
                 name=care_model_path.name,
                 basedir=care_model_path.parent)
    care_result_root = lfd_path.parent.parent / "care" / "pred"
    care_result_root.mkdir(parents=True, exist_ok=True)
    print("saving CARE reconstructions to", care_result_root)

    for file_path in tqdm(list(lfd_path.glob("*tif"))):
        result_path = care_result_root / file_path.name
        if result_path.exists() and not overwrite:
            continue

        x = imread(str(file_path)).squeeze()
        restored = model.predict(x, axes)
        restored = restored.squeeze()[None, ...]
        save_tensor(result_path, restored)
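
A usage sketch under the folder-layout assumption that the asserts encode: LFD predictions live under <root>/lfd/pred and the CARE reconstructions are written next to them under <root>/care/pred (the concrete paths below are hypothetical):

from pathlib import Path

tst_care(Path('models/my_care_model'), Path('runs/exp01/lfd/pred'), overwrite=False)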
Example #15
    def _build():
        with pytest.raises(FileNotFoundError):
            CARE(None, basedir=str(tmpdir))

        CARE(config, name='model', basedir=None)
        with pytest.raises(ValueError):
            CARE(None, basedir=None)

        CARE(config, basedir=str(tmpdir)).export_TF()

        with pytest.warns(UserWarning):
            CARE(config, name='model', basedir=str(tmpdir))
            CARE(config, name='model', basedir=str(tmpdir))
            CARE(None, name='model', basedir=str(tmpdir))
Example #16
def test_model_predict(tmpdir,config):
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config,basedir=str(tmpdir))
    axes = config.axes

    def _predict(imdims,axes):
        img = rng.uniform(size=imdims)
        # print(img.shape, axes, config.n_channel_out)
        if config.probabilistic:
            prob = model.predict_probabilistic(img, axes, normalizer, resizer)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        else:
            mean = model.predict(img, axes, normalizer, resizer)

        if 'C' not in axes:
            if config.n_channel_out == 1:
                assert mean.shape == img.shape
            else:
                assert mean.shape == img.shape + (config.n_channel_out,)
        else:
            channel = axes_dict(axes)['C']
            imdims[channel] = config.n_channel_out
            assert mean.shape == tuple(imdims)


    imdims = list(rng.randint(20,40,size=config.n_dim))
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in imdims]

    if config.n_channel_in == 1:
        _predict(imdims,axes=axes.replace('C',''))

    channel = rng.randint(0,config.n_dim)
    imdims.insert(channel,config.n_channel_in)
    _axes = axes.replace('C','')
    _axes = _axes[:channel]+'C'+_axes[channel:]
    _predict(imdims,axes=_axes)
Example #17
def predict(path_data, model_name, n_tiles, axes, plot_prediction, stack_nb,
            filter_data, folder_name_save):
    """Predicts the output of the network using a pre-trained model.

    Parameters
    ----------
    path_data : str
        Path to input data to predict
    model_name : str
        Name of the pre-trained model
    n_tiles : tuple
        Size of tile to split the patches (helps avoid out-of-memory problems when predicting)
    axes : str
        Semantic order of the channels
    plot_prediction : bool
        Whether or not the prediction is plotted
    filter_data : str
        Substring used to select which files are processed ('all' processes every .tif file)
    folder_name_save : str
        Name of the folder where the reconstructions are saved
    """
    sep = '/'
    model_path = sep.join(model_name.split(sep)[:-1])
    #model_path = '/models/'
    print('-------', model_path)
    model_name = model_name.split(sep)[-1]

    model = CARE(config=None, name=model_name, basedir=model_path)

    for file_ in sorted(os.listdir(path_data)):

        if file_.endswith('.tif') and not pathlib.Path(
                os.path.dirname(os.getcwd()) + '/predicted/' + file_).exists():

            if filter_data in file_:
                reconstruction(model, file_, path_data, axes, n_tiles,
                               plot_prediction, folder_name_save)

            elif filter_data == 'all':
                reconstruction(model, file_, path_data, axes, n_tiles,
                               plot_prediction, folder_name_save)
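
A hedged call sketch, with every argument value below being illustrative: the model name is passed as a path whose last component is the model folder, and filter_data='all' reconstructs every .tif file in the input folder:

predict(path_data='data/raw/',
        model_name='models/CARE_restoration_model',
        n_tiles=(1, 4, 4),
        axes='ZYX',
        plot_prediction=False,
        stack_nb=0,
        filter_data='all',
        folder_name_save='restored')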
Example #18
import os
import glob

from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE

# In[3]:

basedirLow = '/local/u934/private/v_kapoor/ProjectionTraining/MasterLow/VeryLow/'
basedirResults3D = '/local/u934/private/v_kapoor/ProjectionTraining/MasterLow/NotsoLow/'
ModelName = 'BorialisS1S2FlorisMidNoiseModel'
BaseDir = '/data/u934/service_imagerie/v_kapoor/CurieDeepLearningModels/'
Path(basedirResults3D).mkdir(exist_ok=True)

# In[4]:

model = CARE(config=None, name=ModelName, basedir=BaseDir)

# In[6]:

Raw_path = os.path.join(basedirLow, '*tif')

axes = 'ZYX'
smallaxes = 'YX'
filesRaw = glob.glob(Raw_path)

filesRaw.sort()
print(len(filesRaw))
for fname in filesRaw:

    x = imread(fname)
    print(x.shape)
Example #19
basedir = '/run/user/1000/gvfs/smb-share:server=isiserver.curie.net,share=u934/equipe_bellaiche/m_gracia/20210316/partie1'

basedirResults3D = basedir + '/Restored'
basedirResults2D = basedir + '/Projected'
basedirResults3Dextended = basedirResults3D + '/Restored'
basedirResults2Dextended = basedirResults2D + '/Projected'

Model_Dir = '/run/media/sancere/DATA/Lucas_Model_to_use/CARE/'

# In[3]:

RestorationModel = 'CARE_restoration_Borealis_Bin1'
ProjectionModel = 'CARE_projection_Borealis_Bin1'

RestorationModel = CARE(config=None, name=RestorationModel, basedir=Model_Dir)
ProjectionModel = ProjectionCARE(config=None,
                                 name=ProjectionModel,
                                 basedir=Model_Dir)

# In[5]:

Path(basedirResults3D).mkdir(exist_ok=True)
Path(basedirResults2D).mkdir(exist_ok=True)

Raw_path = os.path.join(basedir, '*TIF')  #tif or TIF be careful

axes = 'ZYX'  #projection axes : 'YX'

filesRaw = glob.glob(Raw_path)
Example #20
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    '''
  Main function of the script. Train the model and save it in model_path.

  Parameters
  ----------
  Training_source : (str) Path to the noisy images
  Training_target : (str) Path to the GT images
  model_name : (str) name of the model
  model_path : (str) path of the model
  Visual_validation_after_training : (bool) Predict a random image after training
  number_of_epochs : (int) number of epochs
  patch_size : (int) patch size
  number_of_patches : (int) number of patches for each image
  Use_Default_Advanced_Parameters : (bool) Use default parameters for the training
  number_of_steps : (int) number of steps
  batch_size : (int) batch size
  percentage_validation : (int) percentage of the data used for validation

  Return
  ------
  void
  '''
    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"
    training_data = base + "/my_training_data.npz"
    if (Use_Default_Advanced_Parameters):
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100
    # here we check that no model with the same name already exists; if so, delete it
    if os.path.exists(model_path + '/' + model_name):
        shutil.rmtree(model_path + '/' + model_name)

    # The shape of the images.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData Object

    # This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.
    # This file is saved in .npz format and later called when loading the training data.

    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    rawdata1 = training_path + ".npz"
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val,
             Y_val), axes = load_training_data(rawdata1,
                                               validation_split=percentage,
                                               verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
    #Show_patches(X,Y)

    # Here we automatically define number_of_steps based on the training data and batch size
    if (Use_Default_Advanced_Parameters):
        number_of_steps = int(X.shape[0] / batch_size) + 1

    print(number_of_steps)

    #Here we create the configuration file

    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)
    vars(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    if (Visual_validation_after_training):
        Cell_executed = 1

    import time
    start = time.time()

    #@markdown ##Start Training

    # Start Training
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))

    print("Training, done.")
    if (Visual_validation_after_training):
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Displaying the time elapsed for training
    dt = time.time() - start
    min, sec = divmod(dt, 60)
    hour, min = divmod(min, 60)
    print("Time elapsed:", hour, "hour(s)", min, "min(s)", round(sec),
          "sec(s)")

    Show_loss_function(history, model_path)
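
A minimal call sketch with hypothetical Colab-style paths, leaving the advanced parameters at their defaults (so batch size and validation split are overridden inside the function):

train(Training_source='/content/low_snr',
      Training_target='/content/high_snr',
      model_name='care_2d_denoiser',
      model_path='/content/models',
      Visual_validation_after_training=False,
      number_of_epochs=10,
      patch_size=64,
      number_of_patches=10)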
Example #21
from argparse import ArgumentParser
from os import makedirs
from os.path import basename, exists, join

from csbdeep.models import CARE
from skimage.io import ImageCollection, imread, imsave
from skimage import img_as_float, img_as_ubyte

args = ArgumentParser()
args.add_argument("--base_dir")
args.add_argument("--model")
args.add_argument("--out_dir")
args.add_argument("--data")
#args.add_argument("base_dir")
#args.add_argument("model")
#args.add_argument("out_dir")
#args.add_argument("data")
args.add_argument("--is_3d", default=False)
args = args.parse_args()

model = CARE(config=None, name=args.model, basedir=args.base_dir)
#data = ImageCollection("training_data/val/low_snr_extracted_z/*.tif")
data = ImageCollection(args.data)
axes = "ZYX" if bool(args.is_3d) else "YX"

if not exists(args.out_dir):
    makedirs(args.out_dir)

for i in range(len(data)):
    im = img_as_float(data[i])
    r = model.predict(im, axes)
    #r = (r - r.min()) / (r.max() - r.min())
    r = img_as_ubyte(r)
    imsave(join(args.out_dir, f"{args.model}_{basename(data.files[i])}"), r)

Example #22
#val_data = (val_data["X"], val_data["Y"])
#val_axes = tr_axes # we assume that both training and validation data are saved in the same format

is_3d = bool(args.is_3d)
axes = "ZYX" if is_3d else "YX"
n_dim = 3 if is_3d else 2
config = Config(axes,
                n_dim=n_dim,
                n_channel_in=1,
                n_channel_out=1,
                probabilistic=int(args.probabilistic),
                train_batch_size=int(args.batch_size),
                unet_kern_size=3)

for i in range(int(args.num_models)):
    model = CARE(config, f"model_{i}", args.models_dir)
    train_history = model.train(tr_data[0],
                                tr_data[1],
                                validation_data=val_data,
                                epochs=int(args.epochs),
                                steps_per_epoch=int(args.steps_per_epoch))

exit()

plot_history(train_history, ["loss", "val_loss"],
             ["mse", "val_mse", "mae", "val_mae"])


def crop_to_even(img):
    shape = img.shape
    new_shape = (shape[0] - (shape[0] % 2), shape[1] - (shape[1] % 2))
Example #23
    f.write('{:.2f}\t{:.2f}\t{:.2f}\n'.format(x0 + dlx, y0 + dly, z0 + dlz))


def load_position_log(log_fname):
    try:
        xyz_array = np.loadtxt(log_fname, skiprows=3, delimiter='\t')
        print(xyz_array.shape)
    except:
        print('Unable to open ', log_fname)
        xyz_array = np.zeros((0, 3))
    if xyz_array.shape == (3, ):
        xyz_array = np.zeros((0, 3))
    return xyz_array


model = CARE(config=None, name='bactn-gfp', basedir='CARE_Models')
axes = 'ZYX'

# fig, ax = pl.subplots()
# pl.ion()
# pl.show()
print('Ready to run')
while 1 != 0:
    log_list = [f for f in os.listdir('.') if ('.LOG' in f)]
    #print(log_list)
    for logfile in log_list:
        try:
            dirlog = open(logfile, 'r')
        except:
            break
        success = True
Example #24
def test_model_predict_tiled():
    """
    Test that tiled prediction yields the same
    or a similar result as predicting
    the whole image at once.
    """
    rng = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1],
        n_channel_out=[1],
        probabilistic=[False],
        # unet_residual         = [False,True],
        unet_n_depth=[1, 2, 3],
        unet_kern_size=[3, 5],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        # unet_input_shape      = [(None, None, 1)],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        normalizer, resizer = NoNormalizer(), NoResizer()

        for config in filter(lambda c: c.is_valid(), configs):
            K.clear_session()
            model = CARE(config, basedir=tmpdir)

            def _predict(imdims, axes, n_tiles):
                img = rng.uniform(size=imdims)
                # print(img.shape, axes)
                mean, scale = model._predict_mean_and_scale(img,
                                                            axes,
                                                            normalizer,
                                                            resizer,
                                                            n_tiles=1)
                mean_tiled, scale_tiled = model._predict_mean_and_scale(
                    img, axes, normalizer, resizer, n_tiles=n_tiles)
                assert mean.shape == mean_tiled.shape
                if config.probabilistic:
                    assert scale.shape == scale_tiled.shape
                error_max = np.max(np.abs(mean - mean_tiled))
                # print('n, k, err = {0}, {1}x{1}, {2}'.format(model.config.unet_n_depth, model.config.unet_kern_size, error_max))
                assert error_max < 1e-3
                return mean, mean_tiled

            imdims = list(rng.randint(100, 130, size=config.n_dim))
            if config.n_dim == 3:
                imdims[0] = 32  # make one dim small, otherwise test takes too long
            div_n = 2**config.unet_n_depth
            imdims = [(d // div_n) * div_n for d in imdims]

            n_blocks = np.max(imdims) // div_n

            def _predict_wrapped(imdims, axes, n_tiles):
                if 0 < n_tiles <= n_blocks:
                    _predict(imdims, axes, n_tiles=n_tiles)
                else:
                    with pytest.warns(UserWarning):
                        _predict(imdims, axes, n_tiles=n_tiles)

            imdims.insert(0, config.n_channel_in)
            axes = config.axes.replace('C', '')
            # return _predict(imdims,'C'+axes,n_tiles=(3,4))

            # tile one dimension
            for n_tiles in (0, 2, 3, 6, n_blocks + 1):
                if config.n_channel_in == 1:
                    _predict_wrapped(imdims[1:], axes, n_tiles)
                _predict_wrapped(imdims, 'C' + axes, n_tiles)

            # tile two dimensions
            for n_tiles in product((2, 4), (3, 5)):
                _predict(imdims, 'C' + axes, n_tiles)

            # tile three dimensions
            if config.n_dim == 3:
                _predict(imdims, 'C' + axes, (2, 3, 4))
Example #25
 def _build():
     CARE(config, basedir=tmpdir)
Example #26
    def Train(self):

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:

            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')

            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)
            vars(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name + '/' +
                                  'weights_now.h5') and os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDist3D model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if self.CroppedLoad == False:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                          ], [self.Y[i] for i in ind_val]
                self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                          ], [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)
            vars(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_now.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_now.h5')
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_last.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_last.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_best.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_best.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_best.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Example #27
def main():
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" %
                  (parser.prog, r),
                  file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if len(file_list) == 0:
        log("No files to process in '%s' with pattern '%s'." %
            (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay imports until after checking that all required arguments are provided
    from tifffile import imread, imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        print("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin,
                                      pmax=args.norm_pmax,
                                      do_after=args.norm_undo)

    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files
    for file_in in tqdm(file_list,
                        disable=args.quiet
                        or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=Path(args.model_weights).stem
            if args.model_weights is not None else None)

        # checks
        (file_in.suffix.lower() in ('.tif', '.tiff')
         and file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(
             ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img,
                                 axes=args.input_axes,
                                 normalizer=normalizer,
                                 n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            assert restored.ndim == img.ndim + 1
            assert 'C' not in axes_out
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s' %
            (n_processed, 'files' if n_processed > 1 else 'file'))
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(
                1 + i, str(file_in)))
            log(('{:>%d}  out: {:>%d}' % (len_processed, len_file)).format(
                '', str(file_out)))
Example #28
                         unet_n_first=48,
                         unet_kern_size=7,
                         train_loss='mae',
                         train_epochs=150,
                         train_learning_rate=1.0E-4,
                         train_batch_size=1,
                         train_reduce_lr={
                             'patience': 5,
                             'factor': 0.5
                         })
print(config)
vars(config)

# In[6]:

model = CARE(config=config, name=ModelName, basedir=ModelDir)
#input_weights = ModelDir + ModelName + '/' +'weights_best.h5'
#model.load_weights(input_weights)

# In[ ]:

history = model.train(X, Y, validation_data=(X_val, Y_val))

# In[ ]:

print(sorted(list(history.history.keys())))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'],
             ['mse', 'val_mse', 'mae', 'val_mae'])

# In[ ]:
Example #29
        help="the name of your model (will be stored in model_dir)",
        default="my_model")
    parser.add_argument(
        "--path_to_test_image",
        help="the path to the test high/low pair, if available",
        default="test")
    # parser.add_argument(
    #   "-o",
    #   "--output_filename",
    #   help="the name of the out"
    # )
    parser.add_argument("-p", "--plot", action="store_true", default=True)
    args = parser.parse_args()

    # Load the trained model
    model = CARE(config=None, name=args.model_name, basedir=args.model_dir)

    # Read the test images
    x = imread(
        os.path.join(args.base_dir, args.path_to_test_image, "low_snr.tif"))
    y = imread(
        os.path.join(args.base_dir, args.path_to_test_image, "high_snr.tif"))

    restored = model.predict(x, axes="YX")  # , n_tiles=(1,4,4))

    # Save the restored image
    save_tiff_imagej_compatible(
        os.path.join(args.base_dir, args.path_to_test_image, "predicted.tif"),
        restored,
        axes="YX")
Example #30
basedir = '/run/user/1000/gvfs/smb-share:server=isiserver.curie.net,share=u934/equipe_bellaiche/el_alpar/210609_ON_ActTolloRNAi_lateral'

basedirResults3D = basedir + '/Restored'
basedirResults2D = basedir + '/Projected'
basedirResults3Dextended = basedirResults3D + '/Restored'
basedirResults2Dextended = basedirResults2D + '/Projected'

Model_Dir = '/run/media/sancere/DATA/Lucas_Model_to_use/CARE/'
#Test_change_email_adress_to_see_commit

# In[3]:

RestorationModel = 'CARE_restoration_SpinWideFRAP4_Bin1_3Gfp'
ProjectionModel = 'CARE_projection_SpinWideFRAP4_Bin1_3Gfp'

RestorationModel = CARE(config=None, name=RestorationModel, basedir=Model_Dir)
ProjectionModel = ProjectionCARE(config=None,
                                 name=ProjectionModel,
                                 basedir=Model_Dir)

# In[5]:

#Path(basedirResults3D).mkdir(exist_ok = True)
Path(basedirResults2D).mkdir(exist_ok=True)

Raw_path = os.path.join(basedir, '*TIF')  #tif or TIF be careful

axes = 'ZYX'  #projection axes : 'YX'

filesRaw = glob.glob(Raw_path)