Example #1
    def export_TF(self, fname=None, single_output=True, upsample_grid=True):
        """Export model to TensorFlow's SavedModel format that can be used e.g. in the Fiji plugin

        Parameters
        ----------
        fname : str
            Path of the zip file to store the model
            If None, the default path "<modeldir>/TF_SavedModel.zip" is used
        single_output: bool
            If set, concatenates the two model outputs into a single output (note: this is currently mandatory for further use in Fiji)
        upsample_grid: bool
            If set, upsamples the output to the input shape (note: this is currently mandatory for further use in Fiji)
        """
        Concatenate, UpSampling2D, UpSampling3D, Conv2DTranspose, Conv3DTranspose = keras_import(
            'layers', 'Concatenate', 'UpSampling2D', 'UpSampling3D',
            'Conv2DTranspose', 'Conv3DTranspose')
        Model = keras_import('models', 'Model')

        if self.basedir is None and fname is None:
            raise ValueError(
                "Need explicit 'fname', since model directory not available (basedir=None)."
            )

        grid = self.config.grid
        prob = self.keras_model.outputs[0]
        dist = self.keras_model.outputs[1]
        assert self.config.n_dim in (2, 3)

        if upsample_grid and any(g > 1 for g in grid):
            # CSBDeep Fiji plugin needs same size input/output
            # -> we need to upsample the outputs if grid > (1,1)
            # note: upsampling prob with a transposed convolution creates a sparse
            #       prob output with fewer candidates than standard upsampling
            conv_transpose = Conv2DTranspose if self.config.n_dim == 2 else Conv3DTranspose
            upsampling = UpSampling2D if self.config.n_dim == 2 else UpSampling3D
            prob = conv_transpose(1, (1, ) * self.config.n_dim,
                                  strides=grid,
                                  padding='same',
                                  kernel_initializer='ones',
                                  use_bias=False)(prob)
            dist = upsampling(grid)(dist)

        inputs = self.keras_model.inputs[0]
        outputs = Concatenate()([prob, dist]) if single_output else [prob, dist]
        csbdeep_model = Model(inputs, outputs)

        fname = (self.logdir / 'TF_SavedModel.zip') if fname is None else Path(fname)
        export_SavedModel(csbdeep_model, str(fname))
        return csbdeep_model
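
A minimal usage sketch (assuming a trained 2D model saved under "models/my_model"; the model name and directory are placeholders):

# Hypothetical usage: load a trained model and export it for the Fiji plugin.
from stardist.models import StarDist2D
model = StarDist2D(None, name='my_model', basedir='models')
model.export_TF()  # writes models/my_model/TF_SavedModel.zip by default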
Example #2
def test_stardistdata_sequence():
    import numpy as np
    from stardist.models import StarDistData3D
    from stardist import Rays_GoldenSpiral
    from csbdeep.utils.tf import keras_import
    Sequence = keras_import('utils', 'Sequence')

    x = np.zeros((10, 32, 48, 64), np.uint16)
    x[:, 10:-10, 10:-10] = 1

    class MyData(Sequence):
        def __init__(self, dtype):
            self.dtype = dtype

        def __getitem__(self, n):
            # cast to the requested dtype (float32 images, uint16 labels)
            return x[n].astype(self.dtype, copy=False)

        def __len__(self):
            return len(x)

    X = MyData(np.float32)
    Y = MyData(np.uint16)
    s = StarDistData3D(X,
                       Y,
                       batch_size=1,
                       patch_size=(32, 32, 32),
                       rays=Rays_GoldenSpiral(64),
                       length=1)
    (img, ), (prob, dist) = s[0]
    return (img, ), (prob, dist), s
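
A hypothetical driver to exercise the test above and inspect what the generator yields (exact shapes depend on the rays and grid configuration):

(img,), (prob, dist), s = test_stardistdata_sequence()
print(img.shape, prob.shape, dist.shape)  # batch and channel axes included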
Example #3
def test_stardistdata_sequence():
    import numpy as np
    from stardist.models import StarDistData2D
    from csbdeep.utils.tf import keras_import
    Sequence = keras_import('utils', 'Sequence')

    x = np.zeros((100, 100), np.uint16)
    x[10:-10, 10:-10] = 1

    class MyData(Sequence):
        def __init__(self, dtype):
            self.dtype = dtype

        def __getitem__(self, n):
            return ((1 + n) * x).astype(self.dtype)

        def __len__(self):
            return 1000

    X = MyData(np.float32)
    Y = MyData(np.uint16)
    s = StarDistData2D(X,
                       Y,
                       batch_size=1,
                       patch_size=(100, 100),
                       n_rays=32,
                       length=1)
    (img, ), (prob, dist) = s[0]
    return (img, ), (prob, dist), s
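
Note that keras_import dispatches to whichever Keras implementation the installed environment provides; under TensorFlow 2, the Sequence import above is equivalent to a direct import (a sketch assuming a TF2 environment):

from csbdeep.utils.tf import keras_import
Sequence = keras_import('utils', 'Sequence')
# under TF2 this resolves to the same class as:
from tensorflow.keras.utils import Sequence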
Example #4
import numpy as np
from pathlib import Path
from stardist.models import Config3D, StarDist3D
from csbdeep.utils.tf import keras_import
Sequence = keras_import('utils', 'Sequence')
LambdaCallback = keras_import('callbacks', 'LambdaCallback')
import argparse
import resource
import sys
import time
from functools import lru_cache


def print_memory():
    mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    unit = 1e9 if sys.platform == "darwin" else 1e6
    print(f"\n >>>> total memory used: {mem/unit:.2f} GB \n", flush=True)


class LargeSequence(Sequence):
    def __init__(self, n=1000, size=256):
        self.n = n
        self.data = np.zeros((size, size, size), np.uint16)
        self.data[1:-1, 1:-1, 1:-1] = 1

    # @lru_cache(maxsize=4)
    def __getitem__(self, n):
        return self.data.copy()

    def __len__(self):
        return self.n
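
The imported LambdaCallback can wire print_memory into a training loop so that peak RSS is reported once per epoch; the wiring below is an assumption, not part of the original script:

# Hypothetical: report memory usage at the end of every epoch.
memory_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: print_memory())
# e.g. pass callbacks=[memory_callback] to a Keras fit()/train() call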
Example #5
from __future__ import absolute_import, print_function

from .model2d import Config2D, SplineDist2D, SplineDistData2D

from csbdeep.utils import backend_channels_last
from csbdeep.utils.tf import keras_import
K = keras_import('backend')
if not backend_channels_last():
    raise NotImplementedError(
        "Keras is configured to use the '%s' image data format, which is currently not supported. "
        "Please change it to use 'channels_last' instead: "
        "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored"
        % K.image_data_format())
del backend_channels_last, K

from csbdeep.models import register_model, register_aliases, clear_models_and_aliases
# register pre-trained models and aliases (TODO: replace with updatable solution)

del register_model, register_aliases, clear_models_and_aliases
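
If the check above fails, the data format can also be switched programmatically instead of editing keras.json (a sketch assuming a TF2-backed Keras):

import tensorflow as tf
tf.keras.backend.set_image_data_format('channels_last')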
Example #6
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
import sys
import warnings
import math
from tqdm import tqdm
from collections import namedtuple
from pathlib import Path

from csbdeep.models.base_model import BaseModel
from csbdeep.utils.tf import export_SavedModel, keras_import, IS_TF_1, CARETensorBoard

import tensorflow as tf
K = keras_import('backend')
Sequence = keras_import('utils', 'Sequence')
Adam = keras_import('optimizers', 'Adam')
ReduceLROnPlateau, TensorBoard = keras_import('callbacks', 'ReduceLROnPlateau', 'TensorBoard')

from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict, load_json, save_json
from csbdeep.internals.predict import tile_iterator
from csbdeep.internals.train import RollingSequence
from csbdeep.data import Resizer

from ..sample_patches import get_valid_inds
from ..utils import _is_power_of_2, optimize_threshold

import splinegenerator as sg

def generic_masked_loss(mask, n_params, loss, weights=1, norm_by_mask=True, reg_weight=0, reg_penalty=K.abs):
    def _loss(y_true, y_pred):  
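        # (snippet truncated here in the source; the body below is a sketch
        #  based on stardist's published generic_masked_loss. splinedist's
        #  variant likely matches, though its use of n_params is not shown.)
        actual_loss = K.mean(mask * weights * loss(y_true, y_pred), axis=-1)
        norm_mask = (K.mean(mask) + K.epsilon()) if norm_by_mask else 1
        if reg_weight > 0:
            # regularize predictions at pixels outside the mask
            reg_loss = K.mean((1 - mask) * reg_penalty(y_pred), axis=-1)
            return actual_loss / norm_mask + reg_weight * reg_loss
        else:
            return actual_loss / norm_mask
    return _loss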
Example #7
# Module-level imports assumed by this snippet (parse_args is defined
# elsewhere in the same script):
import sys
import numpy as np
from pathlib import Path
from pprint import pprint
from tqdm import tqdm
from csbdeep.utils import _raise, axes_check_and_normalize
from csbdeep.io import save_tiff_imagej_compatible


def main():
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" %
                  (parser.prog, r),
                  file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if len(file_list) == 0:
        log("No files to process in '%s' with pattern '%s'." %
            (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay heavy imports until all required arguments have been checked
    from tifffile import imread, imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        print("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin,
                                      pmax=args.norm_pmax,
                                      do_after=args.norm_undo)

    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files
    for file_in in tqdm(file_list,
                        disable=args.quiet
                        or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=Path(args.model_weights).stem
            if args.model_weights is not None else None)

        # checks
        (file_in.suffix.lower() in ('.tif', '.tiff')
         and file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(
             ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img,
                                 axes=args.input_axes,
                                 normalizer=normalizer,
                                 n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            assert restored.ndim == img.ndim + 1
            assert 'C' not in axes_out
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s' %
            (n_processed, 'files' if n_processed > 1 else 'file'))
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(
                1 + i, str(file_in)))
            log(('{:>%d}  out: {:>%d}' % (len_processed, len_file)).format(
                '', str(file_out)))
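
For reference, the --output-name template consumed by args.output_name.format(...) above accepts the keys file_path, file_name, file_ext, model_name, and model_weights; a hypothetical template:

template = '{file_path}/{file_name}_{model_name}{file_ext}'
print(template.format(file_path='sub/dir', file_name='img01', file_ext='.tif',
                      model_name='my_model', model_weights=None))
# -> sub/dir/img01_my_model.tif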
Example #8
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
import warnings
import math
from tqdm import tqdm

from csbdeep.models import BaseConfig
from csbdeep.internals.blocks import unet_block
from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict
from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage
from skimage.segmentation import clear_border
from distutils.version import LooseVersion

keras = keras_import()
K = keras_import('backend')
Input, Conv2D, MaxPooling2D = keras_import('layers', 'Input', 'Conv2D',
                                           'MaxPooling2D')
Model = keras_import('models', 'Model')

from .base import SplineDistBase, SplineDistDataBase
from ..sample_patches import sample_patches
from ..utils import edt_prob, _normalize_grid
from ..geometry import spline_dist, dist_to_coord, polygons_to_label
from ..nms import non_maximum_suppression


class SplineDistData2D(SplineDistDataBase):
    def __init__(self,
                 X,
                 Y,
Example #9
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
import warnings
import math
from tqdm import tqdm

from csbdeep.models import BaseConfig
from csbdeep.internals.blocks import conv_block3, unet_block, resnet_block
from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict
from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage
from distutils.version import LooseVersion

keras = keras_import()
K = keras_import('backend')
Input, Conv3D, MaxPooling3D, UpSampling3D, Add, Concatenate = keras_import(
    'layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Add',
    'Concatenate')
Model = keras_import('models', 'Model')

from .base import StarDistBase, StarDistDataBase
from ..sample_patches import sample_patches
from ..utils import edt_prob, _normalize_grid
from ..matching import relabel_sequential
from ..geometry import star_dist3D, polyhedron_to_label
from ..rays3d import Rays_GoldenSpiral, rays_from_json
from ..nms import non_maximum_suppression_3d


class StarDistData3D(StarDistDataBase):
    def __init__(self,
Example #10
File: utils.py Project: stardist/stardist
import os
import numpy as np
from tifffile import imread
from skimage.measure import label
from scipy.ndimage import gaussian_filter
from pathlib import Path
from timeit import default_timer
from csbdeep.utils.tf import keras_import
Sequence = keras_import('utils', 'Sequence')


class NumpySequence(Sequence):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, n):
        return self.data[n]

    def __len__(self):
        return len(self.data)
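
NumpySequence simply adapts an in-memory array stack to the Keras Sequence protocol; a hypothetical use:

imgs = NumpySequence(np.zeros((4, 100, 100), np.float32))
print(len(imgs), imgs[0].shape)  # 4 (100, 100)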


class Timer(object):
    def __init__(self, message="elapsed", fmt=" {1000*t:.2f} ms"):
        self.message = message
        self.fmt = fmt

    def __enter__(self):
        self.start = default_timer()
        return self
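    # (snippet truncated here in the source; a plausible __exit__ follows as a
    #  sketch. Note the default fmt uses an expression, '{1000*t:.2f}', that
    #  str.format cannot evaluate, so elapsed milliseconds are printed directly.)
    def __exit__(self, *exc):
        self.t = default_timer() - self.start
        print(f"{self.message}:{1000 * self.t:.2f} ms", flush=True)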