Пример #1
0
    def __check_dims(self, im, block_size, overlap):
        """Validate and normalize the block size and overlap used for blockwise inference.

        Params:
            -im : array whose 3-element ``shape`` is (depth, height, width); only the shape is used
            -block_size : scalar or sequence with one entry per image dim; a scalar is
                          broadcast to every dim
            -overlap : scalar or sequence with one entry per image dim; values <= 1 are taken
                       as a fraction of the corresponding block size (so 1 means the whole
                       block), values > 1 as a pixel count

        Returns:
            (block_size, overlap) as lists of ints. ``overlap`` is even and is 0 along dims
            where the block covers the whole image; ``block_size`` is the inner block size
            with the overlap removed on both sides (``requested_block - 2 * overlap``).

        Raises:
            ValueError if ``im`` is not 3-D or block_size/overlap do not have one entry per dim.
        """
        # To-Do: deal with float block_size
        def __broadcast(val):
            # a scalar applies to every image dimension
            return [val for _ in im.shape] if np.isscalar(val) else val

        len(im.shape) == 3 or _raise(ValueError('the input image must be in shape [depth, height, width]'))
        block_size = __broadcast(block_size)
        overlap    = __broadcast(overlap)

        len(block_size) == len(im.shape) or _raise(ValueError("ndim of block_size ({}) mismatch that of image size ({})".format(block_size, im.shape)))
        len(overlap) == len(im.shape) or _raise(ValueError("ndim of overlap ({}) mismatch that of image size ({})".format(overlap, im.shape)))

        # Convert to whole pixels *before* enforcing evenness. Otherwise a fractional overlap
        # (e.g. 0.3 * 64 = 19.2 -> 20.2 -> 20) leaves block_size + 2 * overlap != requested block.
        overlap = [int(round(i)) if i > 1 else int(round(i * s)) for i, s in zip(overlap, block_size)]
        overlap = [0 if b >= s else i for i, b, s in zip(overlap, block_size, im.shape)] # no overlap along the dims where the image size equal to the block size
        overlap = [i if i % 2 == 0 else i + 1 for i in overlap]                          # overlap must be even number

        block_size = [int(b - 2 * i) for b, i in zip(block_size, overlap)]               # real block size when inference

        print('block size (overlap excluded) : {} overlap : {}'.format(block_size, overlap))

        return block_size, overlap
Пример #2
0
 def start_response(status, headerlist, exc_info=None):
     """Callable with the PEP 3333 ``start_response(status, headers, exc_info=None)`` signature.

     When ``exc_info`` is truthy it is handed to ``_raise`` (presumably re-raising it).
     Otherwise the status is stored on the enclosing ``rs`` object, every
     ``(name, value)`` pair of ``headerlist`` is added as a header, and
     ``rs.body.append`` is returned as the write callable.
     """
     if exc_info:
         _raise(*exc_info)
     rs.status = status
     for header_name, header_value in headerlist:
         rs.add_header(header_name, header_value)
     write = rs.body.append
     return write
Пример #3
0
    def predict(self, im, block_size, overlap, normalization='fixed', **kwargs):
        """Normalize ``im``, run blockwise inference on it and undo the normalization.

        Params:
            -im : numpy array of dtype uint8 or uint16
            -block_size, overlap : forwarded unchanged to ``predict_without_norm``
            -normalization : 'fixed' or 'percentile'; selects the normalization method
            -kwargs : keyword arguments forwarded to the selected normalization method

        Returns:
            the output of ``predict_without_norm`` passed through ``__reverse_norm``

        Raises:
            ValueError for an unknown normalization mode or an unsupported image dtype
        """
        normalization in ['fixed', 'percentile'] or _raise(ValueError('unknown normalization mode:%s' % normalization))
        norm_fn = self.__normalize_fixed if normalization == 'fixed' else self.__normalize_percentile

        im_dtype = im.dtype
        im_dtype in [np.uint8, np.uint16] or _raise(ValueError('unknown image dtype:%s' % im_dtype))
        im = norm_fn(im, **kwargs)

        print('normalized to [%.4f, %.4f]' % (np.min(im), np.max(im)))
        sr = self.predict_without_norm(im, block_size, overlap)
        return self.__reverse_norm(sr, normalize_mode=normalization)
Пример #4
0
def sample_patches_from_multiple_stacks(datas, patch_size, n_samples, datas_mask=None, patch_filter=None, verbose=False):
    """Sample matching patches of size `patch_size` from all arrays in `datas`.

    Parameters
    ----------
    datas : sequence of numpy.ndarray
        Arrays of identical shape; the same patch locations are cut out of each of them.
    patch_size : tuple
        Patch shape with one entry per array dimension, each in ``(0, shape[d]]``.
    n_samples : int
        Number of patches to draw. Locations are drawn without replacement unless
        there are fewer valid locations than ``n_samples``.
    datas_mask : numpy.ndarray of bool or None
        Optional pixel mask of the same shape as the data; only patches that lie
        entirely on ``True`` pixels are eligible (untested code path).
    patch_filter : callable or None
        ``patch_filter(datas, patch_size)`` returning a boolean array of eligible patch
        centers. ``None`` makes every location eligible.
    verbose : bool
        Currently unused.

    Returns
    -------
    list of numpy.ndarray
        One array of shape ``(n_samples,) + tuple(patch_size)`` per array in ``datas``.

    Raises
    ------
    ValueError
        Inconsistent shapes/dtypes, invalid patch size, or no valid patch location.
    """

    # TODO: some of these checks are already required in 'create_patches'
    len(patch_size)==datas[0].ndim or _raise(ValueError())

    if not all(( a.shape == datas[0].shape for a in datas )):
        raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))

    if not all(( 0 < s <= d for s,d in zip(patch_size,datas[0].shape) )):
        raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape)))

    if patch_filter is None:
        # builtin `bool`: the alias `np.bool` was removed in NumPy 1.24
        patch_mask = np.ones(datas[0].shape,dtype=bool)
    else:
        patch_mask = patch_filter(datas, patch_size)

    if datas_mask is not None:
        # TODO: Test this
        warnings.warn('Using pixel masks for raw/transformed images not tested.')
        datas_mask.shape == datas[0].shape or _raise(ValueError())
        datas_mask.dtype == bool or _raise(ValueError())
        # `scipy.ndimage.filters` is a deprecated namespace; import from `scipy.ndimage`
        from scipy.ndimage import minimum_filter
        # a center is only eligible if the whole patch around it lies on True mask pixels
        patch_mask &= minimum_filter(datas_mask, patch_size, mode='constant', cval=False)

    # get the valid indices: restrict centers so that every patch fits inside the data
    border_slices = tuple([slice(s // 2, d - s + s // 2 + 1) for s, d in zip(patch_size, datas[0].shape)])
    valid_inds = np.where(patch_mask[border_slices])

    if len(valid_inds[0]) == 0:
        raise ValueError("'patch_filter' didn't return any region to sample from")

    # shift indices back from the cropped view into full-array coordinates
    valid_inds = [v + s.start for s, v in zip(border_slices, valid_inds)]

    # sample (with replacement only if there are not enough distinct locations)
    sample_inds = np.random.choice(len(valid_inds[0]), n_samples, replace=len(valid_inds[0])<n_samples)

    rand_inds = [v[sample_inds] for v in valid_inds]

    res = [np.stack([data[tuple(slice(_r-(_p//2),_r+_p-(_p//2)) for _r,_p in zip(r,patch_size))] for r in zip(*rand_inds)]) for data in datas]

    return res
Пример #5
0
    def _check_inputs(self):
        """Check that every HR/LR training volume pair has consistent dimensions.

        Files from ``self.train_hr_path`` and ``self.train_lr_path`` are paired by sorted
        file name. Each HR dim must be exactly ``self.factor`` times the corresponding LR
        dim. If ``self.factor`` is not set yet, it is inferred from the leading dim of the
        first pair (integer division).

        Raises:
            ValueError if the HR and LR file counts differ, or if a pair differs in its
            number of dims or has a dim ratio other than ``self.factor``.
        """
        print("checking training data dims ... ")
        hr_im_list = sorted(
            tl.files.load_file_list(path=self.train_hr_path,
                                    regx='.*.tif',
                                    printable=False))
        lr_im_list = sorted(
            tl.files.load_file_list(path=self.train_lr_path,
                                    regx='.*.tif',
                                    printable=False))
        len(hr_im_list) == len(lr_im_list) or _raise(
            ValueError("Num of HR and LR not equal"))

        for hr_file, lr_file in zip(hr_im_list, lr_im_list):
            hr = imageio.volread(os.path.join(self.train_hr_path, hr_file))
            lr = imageio.volread(os.path.join(self.train_lr_path, lr_file))
            if not hasattr(self, 'factor'):
                self.factor = hr.shape[0] // lr.shape[0]
            # zip() below silently truncates, so a rank mismatch must be rejected explicitly
            same_ndim = len(hr.shape) == len(lr.shape)
            ratios_ok = all(
                self.factor == hs / ls for hs, ls in zip(hr.shape, lr.shape))
            if not (same_ndim and ratios_ok):
                raise ValueError(
                    'dims mismatch: \n%s %s\n%s %s' %
                    (hr_file, str(hr.shape), lr_file, str(lr.shape)))
Пример #6
0
    def _sample_and_get_statistics(self,
                                   file_list,
                                   interval=100,
                                   low=2,
                                   high=99.8):
        """Estimate normalization thresholds from a subsample of 2-D slices.

        Every ``interval``-th file in ``file_list`` is read with ``imread2d``. The slices are
        stacked into one array of dtype ``self.dtype``, and its ``low``/``high`` percentiles
        become the normalization thresholds.

        Side effects: sets ``self.n_slices``, ``self.p_low``, ``self.p_high``,
        ``self.width`` and ``self.height``.

        Raises:
            ValueError when ``len(file_list)`` is not larger than ``interval``.
        """
        total = len(file_list)
        total > interval or _raise(
            ValueError('n_slices %d < sample interval %d' %
                       (total, interval)))

        picked = [imread2d(file_list[k]) for k in range(0, total, interval)]
        stack = np.asarray(picked, dtype=self.dtype)

        print('samples volume : %s' % str(stack.shape))
        # the stacked samples are expected to be 3-D: (n_samples, height, width)
        _, height, width = stack.shape

        thr_low, thr_high = (np.percentile(stack, q) for q in (low, high))
        print('normalization thres: %.2f, %.2f' % (thr_low, thr_high))

        self.n_slices = total
        self.p_low, self.p_high = thr_low, thr_high
        self.width, self.height = width, height
Пример #7
0
def norm_percentiles(percentiles=sample_percentiles(), relu_last=False):
    """Normalize extracted patches based on percentiles of the corresponding raw image.

    Parameters
    ----------
    percentiles : tuple or callable, optional
        Either a fixed tuple (`pmin`, `pmax`) or a zero-argument function returning such
        a tuple. Patches are (affinely) normalized so that 0 (1) corresponds to the
        `pmin`-th (`pmax`-th) percentile of the raw image (default: :func:`sample_percentiles`).
        A callable is invoked once per patch, so each patch may use different percentiles.
    relu_last : bool, optional
        Flag to indicate whether the last activation of the CARE network is/will be using
        a ReLU activation function (default: ``False``). If set, the target patches use
        the 0th percentile (the minimum) as their low value.

    Returns
    -------
    function
        Function ``(patches_x, patches_y, x, y, mask, channel)`` that does the
        percentile-based normalization, to be used in :func:`create_patches`.

    Raises
    ------
    ValueError
        Illegal arguments.

    Todo
    ----
    ``relu_last`` flag problematic/inelegant.

    """
    if callable(percentiles):
        # probe the sampler once so invalid percentiles fail early
        probe = percentiles()
        _valid_low_high_percentiles(probe) or _raise(ValueError(probe))
        get_percentiles = percentiles
    else:
        _valid_low_high_percentiles(percentiles) or _raise(ValueError(percentiles))

        def get_percentiles():
            return percentiles

    def _normalize(patches_x, patches_y, x, y, mask, channel):
        # one (pmin, pmax) pair per patch
        pmins, pmaxs = zip(*[get_percentiles() for _ in patches_x])

        if channel is None:
            reduce_axes = None
        else:
            # compute percentiles per channel
            reduce_axes = tuple(ax for ax in range(x.ndim) if ax != channel)

        def _perc(arr, q):
            return np.percentile(arr, q, axis=reduce_axes, keepdims=True)

        x_norm = normalize_mi_ma(patches_x, _perc(x, pmins), _perc(x, pmaxs))
        y_low = np.zeros_like(pmins) if relu_last else pmins
        y_norm = normalize_mi_ma(patches_y, _perc(y, y_low), _perc(y, pmaxs))
        return x_norm, y_norm

    return _normalize
Пример #8
0
    def __init__(
            self,
            lr_size,
            hr_size,
            train_lr_path,
            train_hr_path,
            test_lr_path=None,  # if None, the first 4 image pairs in the training set will be used as the test data
            test_hr_path=None,
            mr_size=None,
            train_mr_path=None,
            test_mr_path=None,
            valid_lr_path=None,
            dtype=np.float32,
            normalization='fixed',
            keep_all_blocks=False,
            transforms=None,  # [trans_fn_for_lr, trans_fn_for_hr] or None
            shuffle=True,
            **kwargs  # keyword arguments for transform function
    ):
        """Store the dataset configuration.

        Only attributes are assigned here; no files are read.

        Params:
            -lr_size, hr_size, mr_size : sizes for the LR / HR / MR data, stored as-is
            -train_*_path, test_*_path, valid_lr_path : data directories, stored as-is
            -dtype : dtype stored in ``self.dtype``
            -normalization : 'fixed' or 'percentile' (stored in ``self.normalize``)
            -keep_all_blocks, shuffle : flags stored as-is
            -transforms : [trans_fn_for_lr, trans_fn_for_hr] or None (-> [None, None])
            -kwargs : stored in ``self.transforms_args`` for the transform functions

        Raises:
            ValueError if ``normalization`` is neither 'fixed' nor 'percentile'.
        """

        self.lr_size = lr_size
        self.hr_size = hr_size
        self.mr_size = mr_size

        self.train_lr_path = train_lr_path
        self.train_hr_path = train_hr_path
        self.train_mr_path = train_mr_path
        self.test_lr_path = test_lr_path
        self.test_hr_path = test_hr_path
        self.test_mr_path = test_mr_path
        self.valid_lr_path = valid_lr_path

        # '%s' (was a bare '%', which raised "incomplete format" instead of this message)
        normalization in ['fixed', 'percentile'] or _raise(
            ValueError('unknown normalization mode: %s' % normalization))
        self.normalize = normalization

        self.keep_all_blocks = keep_all_blocks
        self.transforms = transforms if transforms is not None else [
            None, None
        ]
        self.transforms_args = kwargs
        self.dtype = dtype
        self.shuffle = shuffle

        ## if LR measurement is designated for validation during the training
        self.hasValidation = valid_lr_path is not None
        self.hasMR = train_mr_path is not None
        self.hasTest = test_hr_path is not None

        self.prepared = False
Пример #9
0
def no_background_patches(threshold=0.4, percentile=99.9):

    """Returns a patch filter to be used by :func:`create_patches` to determine for each image pair which patches
    are eligible for sampling. The purpose is to only sample patches from "interesting" regions of the raw image that
    actually contain a substantial amount of non-background signal. To that end, a maximum filter is applied to the target image
    to find the largest values in a region.

    Parameters
    ----------
    threshold : float, optional
        Scalar threshold between 0 and 1 that will be multiplied with the (outlier-robust)
        maximum of the image (see `percentile` below) to denote a lower bound.
        Only patches with a maximum value above this lower bound are eligible to be sampled.
    percentile : float, optional
        Percentile value to denote the (outlier-robust) maximum of an image, i.e. should be close to 100.

    Returns
    -------
    function
        Function that takes an image pair `(y,x)` and the patch size as arguments and
        returns a binary mask of the same size as the image (to denote the locations
        eligible for sampling for :func:`create_patches`). Only the first image of the
        pair is inspected. At least one pixel of the binary mask must be ``True``,
        otherwise there are no patches to sample.

    Raises
    ------
    ValueError
        Illegal arguments.
    """

    (np.isscalar(percentile) and 0 <= percentile <= 100) or _raise(ValueError())
    (np.isscalar(threshold)  and 0 <= threshold  <=   1) or _raise(ValueError())

    # `scipy.ndimage.filters` is a deprecated alias namespace; import from `scipy.ndimage`
    from scipy.ndimage import maximum_filter
    def _filter(datas, patch_size, dtype=np.float32):
        image = datas[0]
        if dtype is not None:
            image = image.astype(dtype)
        # make max filter patch_size smaller (half, dims of size 1 stay 1) to avoid
        # patches with only few non-bg pixels close to their border
        patch_size = [(p//2 if p>1 else p) for p in patch_size]
        filtered = maximum_filter(image, patch_size, mode='constant')
        return filtered > threshold * np.percentile(image,percentile)
    return _filter
Пример #10
0
def sample_percentiles(pmin=(1,3), pmax=(99.5,99.9)):
    """Build a sampler of random (low, high) percentile pairs.

    Parameters
    ----------
    pmin : tuple
        Interval ``(a, b)`` from which the low percentile is drawn uniformly.
    pmax : tuple
        Interval ``(c, d)`` from which the high percentile is drawn uniformly.

    Returns
    -------
    function
        Zero-argument function returning ``(pl, ph)``: a freshly sampled low and high percentile.

    Raises
    ------
    ValueError
        If either interval is rejected by ``_valid_low_high_percentiles``, or if
        ``pmin[1]`` is not strictly below ``pmax[0]``.
    """
    for interval in (pmin, pmax):
        _valid_low_high_percentiles(interval) or _raise(ValueError(interval))
    if not pmin[1] < pmax[0]:
        _raise(ValueError())

    def _draw():
        low = np.random.uniform(*pmin)
        high = np.random.uniform(*pmax)
        return low, high

    return _draw
Пример #11
0
    def prepare(self, batch_size, n_epochs):
        '''Load the training data once and record the batch/epoch settings.

        This function must be called after the Dataset instance is created.
        Repeated calls do not reload anything and return the same value as the first call.

        Returns:
            ``self.training_pair_num - self.test_data_split``. ``test_data_split`` is
            set outside this method, presumably to the number of pairs reserved as test
            data — verify against ``_load_training_data``.

        Raises:
            Exception if a configured training/validation data path does not exist.
        '''
        if self.prepared:
            # previously returned the unsplit pair count here, which disagreed with the
            # value returned by the first call
            return self.training_pair_num - self.test_data_split

        os.path.exists(self.train_lr_path) or _raise(
            Exception('lr training data path doesn\'t exist : %s' %
                      self.train_lr_path))
        os.path.exists(self.train_hr_path) or _raise(
            Exception('hr training data path doesn\'t exist : %s' %
                      self.train_hr_path))
        if (self.hasMR) and (not os.path.exists(self.train_mr_path)):
            raise Exception('mr training data path doesn\'t exist : %s' %
                            self.train_mr_path)
        if self.hasValidation and (not os.path.exists(self.valid_lr_path)):
            raise Exception('validation data path doesn\'t exist : %s' %
                            self.valid_lr_path)

        self.training_pair_num = self._load_training_data()

        self.batch_size = batch_size
        self.n_epochs = n_epochs

        self.prepared = True

        print('HR dataset: %s\nLR dataset: %s' % (str(
            self.training_data_hr.shape), str(self.training_data_lr.shape)))
        if self.hasMR:
            print('MR dataset: %s' % str(self.training_data_mr.shape))
        print()
        return self.training_pair_num - self.test_data_split
Пример #12
0
def create_patches(
        raw_data,
        patch_size,
        n_patches_per_image,
        patch_axes    = None,
        save_file     = None,
        transforms    = None,
        patch_filter  = no_background_patches(),
        normalization = norm_percentiles(),
        shuffle       = True,
        verbose       = True,
    ):
    """Create normalized training data to be used for neural network training.

    Parameters
    ----------
    raw_data : :class:`RawData`
        Object that yields matching pairs of raw images.
    patch_size : tuple
        Shape of the patches to be extracted from raw images.
        Must be compatible with the number of dimensions and axes of the raw images.
        As a general rule, use a power of two along all XYZT axes, or at least divisible by 8.
    n_patches_per_image : int
        Number of patches to be sampled/extracted from each raw image pair (after transformations, see below).
    patch_axes : str or None
        Axes of the extracted patches. If ``None``, will assume to be equal to that of transformed raw data.
    save_file : str or None
        File name to save training data to disk in ``.npz`` format (see :func:`csbdeep.io.save_training_data`).
        If ``None``, data will not be saved.
    transforms : list or tuple, optional
        List of :class:`Transform` objects that apply additional transformations to the raw images.
        This can be used to augment the set of raw images (e.g., by including rotations).
        Set to ``None`` to disable. Default: ``None``.
    patch_filter : function, optional
        Function to determine for each image pair which patches are eligible to be extracted
        (default: :func:`no_background_patches`). Set to ``None`` to disable.
    normalization : function, optional
        Function that takes arguments `(patches_x, patches_y, x, y, mask, channel)`, whose purpose is to
        normalize the patches (`patches_x`, `patches_y`) extracted from the associated raw images
        (`x`, `y`, with `mask`; see :class:`RawData`). Default: :func:`norm_percentiles`.
    shuffle : bool, optional
        Randomly shuffle all extracted patches.
    verbose : bool, optional
        Display overview of images, transforms, etc.

    Returns
    -------
    tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
        Returns a tuple (`X`, `Y`, `axes`) with the normalized extracted patches from all (transformed) raw images
        and their axes.
        `X` is the array of patches extracted from source images with `Y` being the array of corresponding target patches.
        The shape of `X` and `Y` is as follows: `(n_total_patches, n_channels, ...)`.
        For single-channel images, `n_channels` will be 1.
        If the (transformed) raw data yields fewer images than `raw_data.size` predicts,
        a warning is issued and only the patches actually extracted are returned.

    Raises
    ------
    ValueError
        Various reasons (including the raw data not yielding any image).

    Example
    -------
    >>> raw_data = RawData.from_folder(basepath='data', source_dirs=['source1','source2'], target_dir='GT', axes='ZYX')
    >>> X, Y, XY_axes = create_patches(raw_data, patch_size=(32,128,128), n_patches_per_image=16)

    Todo
    ----
    - Save created patches directly to disk using :class:`numpy.memmap` or similar?
      Would allow to work with large data that doesn't fit in memory.

    """
    ## images and transforms
    if transforms is None:
        transforms = []
    transforms = list(transforms)
    if patch_axes is not None:
        transforms.append(permute_axes(patch_axes))
    if len(transforms) == 0:
        transforms.append(Transform.identity())


    image_pairs, n_raw_images = raw_data.generator(), raw_data.size
    tf = Transform(*zip(*transforms)) # convert list of Transforms into Transform of lists
    image_pairs = compose(*tf.generator)(image_pairs) # combine all transformations with raw images as input
    n_transforms = np.prod(tf.size)
    n_images = n_raw_images * n_transforms
    n_patches = n_images * n_patches_per_image
    # X and Y, float32 (4 bytes per value)
    n_required_memory_bytes = 2 * n_patches*np.prod(patch_size) * 4

    ## memory check
    _memory_check(n_required_memory_bytes)

    ## summary
    if verbose:
        print('='*66)
        print('%5d raw images x %4d transformations   = %5d images' % (n_raw_images,n_transforms,n_images))
        print('%5d images     x %4d patches per image = %5d patches in total' % (n_images,n_patches_per_image,n_patches))
        print('='*66)
        print('Input data:')
        print(raw_data.description)
        print('='*66)
        print('Transformations:')
        for t in transforms:
            print('{t.size} x {t.name}'.format(t=t))
        print('='*66)
        print('Patch size:')
        print(" x ".join(str(p) for p in patch_size))
        print('=' * 66)

    sys.stdout.flush()

    ## sample patches from each pair of transformed raw images
    X = np.empty((n_patches,)+tuple(patch_size),dtype=np.float32)
    Y = np.empty_like(X)

    n_filled = 0  # number of image pairs whose patches have been written into X/Y
    for i, (x,y,_axes,mask) in tqdm(enumerate(image_pairs),total=n_images,disable=(not verbose)):
        if i >= n_images:
            warnings.warn('more raw images (or transformations thereof) than expected, skipping excess images.')
            break
        if i==0:
            axes = axes_check_and_normalize(_axes,len(patch_size))
            channel = axes_dict(axes)['C']
        # checks
        axes == axes_check_and_normalize(_axes) or _raise(ValueError('not all images have the same axes.'))
        x.shape == y.shape or _raise(ValueError())
        mask is None or mask.shape == x.shape or _raise(ValueError())
        (channel is None or (isinstance(channel,int) and 0<=channel<x.ndim)) or _raise(ValueError())
        channel is None or patch_size[channel]==x.shape[channel] or _raise(ValueError('extracted patches must contain all channels.'))

        _Y,_X = sample_patches_from_multiple_stacks((y,x), patch_size, n_patches_per_image, mask, patch_filter)

        s = slice(i*n_patches_per_image,(i+1)*n_patches_per_image)
        X[s], Y[s] = normalization(_X,_Y, x,y,mask,channel)
        n_filled = i + 1

    if n_filled == 0:
        # without this, `axes`/`channel` below would be undefined
        raise ValueError('raw data did not yield any images (after transformations).')
    if n_filled < n_images:
        # X/Y come from np.empty: drop the never-written slots before they get shuffled in
        warnings.warn('fewer raw images (or transformations thereof) than expected (%d < %d), discarding unused patch slots.' % (n_filled, n_images))
        X = X[:n_filled*n_patches_per_image]
        Y = Y[:n_filled*n_patches_per_image]

    if shuffle:
        shuffle_inplace(X,Y)

    axes = 'SC'+axes.replace('C','')
    if channel is None:
        X = np.expand_dims(X,1)
        Y = np.expand_dims(Y,1)
    else:
        X = np.moveaxis(X, 1+channel, 1)
        Y = np.moveaxis(Y, 1+channel, 1)

    if save_file is not None:
        print('Saving data to %s.' % str(Path(save_file)))
        save_training_data(save_file, X, Y, axes)

    return X,Y,axes
Пример #13
0
def anisotropic_distortions(
        subsample,
        psf,
        psf_axes       = None,
        poisson_noise  = False,
        gauss_sigma    = 0,
        subsample_axis = 'X',
        yield_target   = 'source',
        crop_threshold = 0.2,
    ):
    """Simulate anisotropic distortions.

    Modify the first image (obtained from input generator) along one axis to mimic the
    distortions that typically occur due to low resolution along the Z axis.
    Note that the modified image is finally upscaled to obtain the same resolution
    as the unmodified input image and is yielded as the 'source' image (see :class:`RawData`).
    The mask from the input generator is simply passed through.

    The following operations are applied to the image (in order):

    1. Convolution with PSF
    2. Poisson noise
    3. Gaussian noise
    4. Subsampling along ``subsample_axis``
    5. Upsampling along ``subsample_axis`` (to former size).


    Parameters
    ----------
    subsample : float
        Subsampling factor to mimic distortions along Z.
    psf : :class:`numpy.ndarray` or None
        Point spread function (PSF) that is supposed to mimic blurring
        of the microscope due to reduced axial resolution. Set to ``None`` to disable.
        The array passed in is not modified (it is normalized on a copy).
    psf_axes : str or None
        Axes of the PSF. If ``None``, psf axes are assumed to be the same as of the image
        that it is applied to.
    poisson_noise : bool
        Flag to indicate whether Poisson noise should be applied to the image.
    gauss_sigma : float
        Standard deviation of white Gaussian noise to be added to the image.
    subsample_axis : str
        Subsampling image axis (default X).
    yield_target : str
        Which image from the input generator should be yielded by the generator ('source' or 'target').
        If 'source', the unmodified input/source image (from which the distorted image is computed)
        is yielded as the target image. If 'target', the target image from the input generator is simply
        passed through.
    crop_threshold : float
        The subsample factor must evenly divide the image size along the subsampling axis to prevent
        potential image misalignment. If this is not the case the subsample factors are
        modified and the raw image may be cropped along the subsampling axis
        up to a fraction indicated by `crop_threshold`.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object intended to be used with :func:`create_patches`.

    Raises
    ------
    ValueError
        Various reasons.

    """
    zoom_order = 1

    (np.isscalar(subsample) and subsample >= 1) or _raise(ValueError('subsample must be >= 1'))
    _subsample = subsample

    subsample_axis = axes_check_and_normalize(subsample_axis)
    len(subsample_axis)==1 or _raise(ValueError())

    psf is None or isinstance(psf,np.ndarray) or _raise(ValueError())
    if psf_axes is not None:
        psf_axes = axes_check_and_normalize(psf_axes)

    0 < crop_threshold < 1 or _raise(ValueError())

    yield_target in ('source','target') or _raise(ValueError())

    if psf is None and yield_target == 'source':
        warnings.warn(
            "It is strongly recommended to use an appropriate PSF to "
            "mimic the optical effects of the microscope. "
            "We found that training with synthesized anisotropic images "
            "that were created without a PSF "
            "can sometimes lead to unwanted artifacts in the reconstructed images."
        )


    def _make_normalize_data(axes_in):
        """Return a function that moves the subsampling axis to the front of an image (``undo=True`` moves it back)."""
        axes_in  = axes_check_and_normalize(axes_in)
        axes_out = subsample_axis
        # add axis in axes_in to axes_out (if it doesn't exist there)
        axes_out += ''.join(a for a in axes_in if a not in axes_out)

        def _normalize_data(data,undo=False):
            if undo:
                return move_image_axes(data, axes_out, axes_in)
            else:
                return move_image_axes(data, axes_in, axes_out)
        return _normalize_data


    def _scale_down_up(data,subsample):
        """Downsample axis 0 by ``subsample`` (order 0), then upsample it back (order ``zoom_order``)."""
        # `scipy.ndimage.interpolation` is a deprecated alias namespace; import from `scipy.ndimage`
        from scipy.ndimage import zoom
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            factor = np.ones(data.ndim)
            factor[0] = subsample
            return zoom(zoom(data, 1/factor, order=0),
                                     factor, order=zoom_order)


    def _adjust_subsample(d,s,c):
        """length d, subsample s, tolerated crop loss fraction c"""
        from fractions import Fraction

        def crop_size(n_digits,frac):
            _s = round(s,n_digits)
            _div = frac.denominator
            s_multiple_max = np.floor(d/_s)
            s_multiple = (s_multiple_max//_div)*_div
            size = s_multiple * _s
            assert np.allclose(size,round(size))
            return size

        def decimals(v,n_digits=None):
            if n_digits is not None:
                v = round(v,n_digits)
            s = str(v)
            assert '.' in s
            decimals = s[1+s.find('.'):]
            return int(decimals), len(decimals)

        s = float(s)
        dec, n_digits = decimals(s)
        frac = Fraction(dec,10**n_digits)
        # a multiple of s that is also an integer number must be
        # divisible by the denominator of the fraction that represents the decimal points

        # round off decimals points if needed
        while n_digits > 0 and (d-crop_size(n_digits,frac))/d > c:
            n_digits -= 1
            frac = Fraction(decimals(s,n_digits)[0], 10**n_digits)

        size = crop_size(n_digits,frac)
        if size == 0 or (d-size)/d > c:
            raise ValueError("subsample factor %g too large (crop_threshold=%g)" % (s,c))

        return round(s,n_digits), int(round(crop_size(n_digits,frac)))


    def _make_divisible_by_subsample(x,size):
        """Center-crop axis 0 of ``x`` to length ``size``."""
        def _split_slice(v):
            return slice(None) if v==0 else slice(v//2,-(v-v//2))
        slices = [slice(None) for _ in x.shape]
        slices[0] = _split_slice(x.shape[0]-size)
        # index with a tuple: indexing with a list of slices is an error since NumPy 1.23
        return x[tuple(slices)]


    def _generator(inputs):
        for img,y,axes,mask in inputs:

            if yield_target == 'source':
                y is None or np.allclose(img,y) or warnings.warn("ignoring 'target' image from input generator")
                target = img
            else:
                target = y

            img.shape == target.shape or _raise(ValueError())

            axes = axes_check_and_normalize(axes)
            _normalize_data = _make_normalize_data(axes)

            x = img.astype(np.float32, copy=False)

            if psf is not None:
                from scipy.signal import fftconvolve
                _psf = psf.astype(np.float32,copy=False)
                np.min(_psf) >= 0 or _raise(ValueError('psf has negative values.'))
                # out-of-place: `astype(copy=False)` may return `psf` itself, and an in-place
                # division would then modify the caller's array
                _psf = _psf / np.sum(_psf)
                if psf_axes is not None:
                    _psf = move_image_axes(_psf, psf_axes, axes, True)
                x.ndim == _psf.ndim or _raise(ValueError('image and psf must have the same number of dimensions.'))

                if 'C' in axes:
                    ch = axes_dict(axes)['C']
                    n_channels = x.shape[ch]
                    # convolve with psf separately for every channel
                    if _psf.shape[ch] == 1:
                        warnings.warn('applying same psf to every channel of the image.')
                    if _psf.shape[ch] in (1,n_channels):
                        x = np.stack([
                            fftconvolve(
                                np.take(x,   i,axis=ch),
                                np.take(_psf,i,axis=ch,mode='clip'),
                                mode='same'
                            )
                            for i in range(n_channels)
                        ],axis=ch)
                    else:
                        raise ValueError('number of psf channels (%d) incompatible with number of image channels (%d).' % (_psf.shape[ch],n_channels))
                else:
                    x = fftconvolve(x, _psf, mode='same')

            if bool(poisson_noise):
                # builtin `int`: the alias `np.int` was removed in NumPy 1.24
                x = np.random.poisson(np.maximum(0,x).astype(int)).astype(np.float32)

            if gauss_sigma > 0:
                noise = np.random.normal(0,gauss_sigma,size=x.shape).astype(np.float32)
                x = np.maximum(0,x+noise)

            if _subsample != 1:
                # move the subsampling axis to the front, crop, down- and upsample, move back
                target = _normalize_data(target)
                x      = _normalize_data(x)

                subsample, subsample_size = _adjust_subsample(x.shape[0],_subsample,crop_threshold)
                if _subsample != subsample:
                    warnings.warn('changing subsample from %s to %s' % (str(_subsample),str(subsample)))

                target = _make_divisible_by_subsample(target,subsample_size)
                x      = _make_divisible_by_subsample(x,     subsample_size)
                x      = _scale_down_up(x,subsample)

                assert x.shape == target.shape, (x.shape, target.shape)

                target = _normalize_data(target,undo=True)
                x      = _normalize_data(x,     undo=True)

            yield x, target, axes, mask


    return Transform('Anisotropic distortion (along %s axis)' % subsample_axis, _generator, 1)
Пример #14
0
 def _generator(inputs):
     """Crop every image pair (and its mask, if any) with the enclosing ``slices``.

     Axes are normalized and must have one entry per element of ``slices``
     (``slices`` comes from the enclosing scope; presumably a tuple of slice objects).
     """
     for x, y, axes, mask in inputs:
         axes = axes_check_and_normalize(axes)
         len(axes) == len(slices) or _raise(ValueError())
         x_crop = x[slices]
         y_crop = y[slices]
         mask_crop = mask[slices] if mask is not None else None
         yield x_crop, y_crop, axes, mask_crop
Пример #15
0
    def _load_training_data(self, shuffle=True):
        """Load HR/LR (and optionally MR, test, validation) training blocks.

        Every file matching ``'.*.tif'`` in each configured directory is cropped
        into non-overlapping blocks. When ``self.keep_all_blocks`` is False,
        an HR block is kept only if its maximum exceeds 20% of the volume's
        98th percentile. The indices of the kept HR blocks are then reused to
        select the matching LR/MR blocks, so the pairs stay aligned.

        Params:
            -shuffle : currently ignored; kept for interface compatibility.

        Returns:
            the number of HR training blocks loaded.

        Raises:
            Exception: if no HR or no LR block could be loaded.
            ValueError: if the numbers of HR/LR (or MR/LR) blocks differ.
        """
        import os

        def _get_im_blocks(path,
                           block_size,
                           dtype=np.float32,
                           transform=None,
                           keep_all=True,
                           keep_list=None,
                           **kwargs):
            """Load image volumes and crop them into blocks for the dataset.

            Params:
                -block_size : [depth height width channels]
                -transform : transformation function applied to each loaded image
                -keep_all : boolean, whether to keep all the blocks
                -keep_list : ascending list of block indices to keep; only
                    consulted when keep_all is False
                -kwargs : keyword args forwarded to transform

            Returns:
                (blocks, keep_list). ``blocks`` is an array of shape
                [n_blocks, depth, height, width, channels]. ``keep_list`` is the
                list passed in, or, when none was given, the indices kept by
                the intensity filter (empty when keep_all is True).
            """
            depth, height, width, _ = block_size  # the desired block size
            # Only the intensity filter needs the percentile, so skip computing
            # it over the whole volume in the other modes.
            filter_by_intensity = (not keep_all) and keep_list is None

            blocks = []
            im_list = sorted(
                tl.files.load_file_list(path=path,
                                        regx='.*.tif',
                                        printable=False))

            block_idx = -1  # running index over all full-size blocks of all images
            idx_saved = []
            keep_list_cursor = 0

            for im_file in im_list:
                im_path = os.path.join(path, im_file)
                im = load_im(im_path, normalize=self.normalize)

                if im.dtype != dtype:
                    im = im.astype(dtype, casting='unsafe')
                print('\r%s : %s ' % (im_path, str(im.shape)), end='')

                if transform is not None:
                    im = transform(im, **kwargs)
                    print('transfrom: %s' % str(im.shape), end='')
                d_real, h_real, w_real, _ = im.shape  # the actual size of the image
                max_val = np.percentile(im, 98) if filter_by_intensity else None

                # The upper bounds drop start positions whose block would run
                # past the image edge, so only full-size blocks are produced.
                for d in range(0, d_real - depth + 1, depth):
                    for h in range(0, h_real - height + 1, height):
                        for w in range(0, w_real - width + 1, width):
                            block = im[d:(d + depth), h:(h + height),
                                       w:(w + width), :]
                            block_idx += 1

                            if keep_all:
                                blocks.append(block)
                            elif filter_by_intensity:
                                if np.max(block) > max_val * 0.2:
                                    blocks.append(block)
                                    idx_saved.append(block_idx)
                            elif (keep_list_cursor < len(keep_list)
                                  and block_idx == keep_list[keep_list_cursor]):
                                # keep_list is ascending, so one cursor walks it in step.
                                blocks.append(block)
                                keep_list_cursor += 1

            print('\nload %d of size %s from %s' %
                  (len(blocks), str(block_size), path))
            blocks = np.asarray(blocks)

            keep_list = idx_saved if keep_list is None else keep_list
            return blocks, keep_list

        # HR blocks decide which block indices survive; LR/MR reuse them.
        self.training_data_hr, valid_indices = _get_im_blocks(
            self.train_hr_path,
            self.hr_size,
            transform=self.transforms[1],
            keep_all=self.keep_all_blocks)
        len(self.training_data_hr) != 0 or _raise(
            Exception(
                "none of the HRs have been loaded, please check the image size ({} desired)"
                .format(str(self.hr_size))))

        self.training_data_lr, _ = _get_im_blocks(
            self.train_lr_path,
            self.lr_size,
            keep_all=self.keep_all_blocks,
            keep_list=valid_indices,
            transform=self.transforms[0],
            **self.transforms_args)
        len(self.training_data_lr) != 0 or _raise(
            Exception(
                "none of the LRs have been loaded, please check the image size ({} desired)"
                .format(str(self.lr_size))))
        self.training_data_hr.shape[
            0] == self.training_data_lr.shape[0] or _raise(
                ValueError("num of LR blocks and HR blocks not equal"))

        # Without a dedicated test set, 20% of the training blocks are held out.
        self.test_data_split = int(len(self.training_data_hr) * 0.2)
        if self.hasTest:
            self.test_data_lr, _ = _get_im_blocks(self.test_lr_path,
                                                  self.lr_size,
                                                  transform=self.transforms[0],
                                                  **self.transforms_args)
            self.test_data_hr, _ = _get_im_blocks(self.test_hr_path,
                                                  self.hr_size,
                                                  transform=self.transforms[1])
            self.test_data_split = 0

        if self.hasMR:
            self.training_data_mr, _ = _get_im_blocks(
                self.train_mr_path,
                self.mr_size,
                keep_all=self.keep_all_blocks,
                keep_list=valid_indices)
            self.training_data_mr.shape[0] == self.training_data_lr.shape[
                0] or _raise(
                    ValueError("num of MR blocks and LR blocks not equal"))
            if self.hasTest:
                self.test_data_mr, _ = _get_im_blocks(self.test_mr_path,
                                                      self.mr_size)

        if self.hasValidation:
            self.valid_data_lr, _ = _get_im_blocks(
                self.valid_lr_path,
                self.lr_size,
                transform=self.transforms[0],
                **self.transforms_args)

        return self.training_data_hr.shape[0]
Пример #16
0
Файл: eval.py Проект: xinDW/DVSR
def evaluate_whole(epoch, load_graph_from_pb=False, half_precision_infer=False, use_cpu=False, large_volume=False, save_pb=True, save_activations=False):
    """Run super-resolution inference on every LR volume in ``valid_lr_img_path``.

    The network is either imported from a frozen ``.pb`` graph
    (``load_graph_from_pb=True``) or rebuilt by ``build_model_and_load_npz``.
    In the rebuild case with ``save_pb=True``, the graph is exported via
    ``save_as_pb`` and the function returns without running inference.

    Params:
        epoch : checkpoint epoch forwarded to ``build_model_and_load_npz``.
        load_graph_from_pb : import a frozen graph instead of building the model.
        half_precision_infer : pick the half-precision ``.pb`` file and
            predictor mode.
        use_cpu : select the CPU-tagged graph / device instead of the GPU one.
        large_volume : stream the data through ``LargeDataPredictor`` instead
            of predicting each volume in memory.
        save_pb : export-only mode (see above).
        save_activations : currently unused.

    Raises:
        ValueError: if the requested ``.pb`` file does not exist.
    """
    device_tag = 'gpu' if not use_cpu else 'cpu'
    graph_file_tag = '%s_%dx%dx%d_%s' % (label.replace("/", "-"), lr_size[0], lr_size[1], lr_size[2], device_tag)

    if load_graph_from_pb:
        graph_file = '%s_half-precision.pb' % (graph_file_tag) if half_precision_infer else graph_file_tag + '.pb'
        model_path = os.path.join(pb_file_dir, graph_file)
        os.path.exists(model_path) or _raise(ValueError('%s doesn\'t exist' % model_path))

        import_name = "hp"
        sess = load_graph(model_path, import_name=import_name, verbose=False)

        LR = sess.graph.get_tensor_by_name("%s/%s:0" % (import_name, input_op_name))
        net = sess.graph.get_tensor_by_name("%s/%s:0" % (import_name, output_op_name))
    else:
        sess, net, LR = build_model_and_load_npz(epoch, use_cpu=use_cpu, save_pb=save_pb)
        if save_pb:
            # Export-only mode: write the frozen graph and skip inference.
            save_as_pb(graph_file_tag, sess=sess)
            return

    exists_or_mkdir(save_dir)
    model = Model(net, sess, LR)
    block_size = lr_size[0:3]

    try:
        if large_volume:
            start_time = time.time()
            predictor = LargeDataPredictor(data_path=valid_lr_img_path,
                saving_path=save_dir,
                factor=factor,
                model=model,
                block_size=block_size,
                overlap=overlap,
                half_precision=half_precision_infer)
            predictor.predict()
            print('time elapsed : %.2fs' % (time.time() - start_time))

        else:
            valid_lr_imgs = get_file_list(path=valid_lr_img_path, regx='.*.tif')
            predictor = Predictor(factor=factor, model=model, half_precision=half_precision_infer)

            for im_file in valid_lr_imgs:
                start_time = time.time()

                print('='*66)
                print('predicting on %s ' % os.path.join(valid_lr_img_path, im_file))

                im = imageio.volread(os.path.join(valid_lr_img_path, im_file))
                if archi1 is None and archi2 == 'unet':
                    # A single-stage U-Net does not upsample itself, so the
                    # input is interpolated to the target size first.
                    im = interpolate3d(im, factor=config.factor)

                sr = predictor.predict(im, block_size, overlap, normalization=normalization)

                print('time elapsed : %.4f' % (time.time() - start_time))
                try:
                    imageio.volwrite(os.path.join(save_dir, 'SR-' + im_file), sr)

                except ValueError:  # data too large for standard TIFF file
                    # Fall back to writing one 2D TIFF per slice.
                    short_name = im_file.split('.tif')[0].replace('.', '_')
                    slice_save_dir = os.path.join(save_dir, 'SR-' + short_name)
                    exists_or_mkdir(slice_save_dir)

                    for d, slice_ in enumerate(sr):
                        name = os.path.join(slice_save_dir, '%05d.tif' % (d + 1))
                        imageio.imwrite(name, slice_)
    finally:
        # Always call recycle() (presumably releases the TF session), even
        # when a prediction or a write raises.
        model.recycle()
Пример #17
0
 def __check_prepared(self):
     """Guard: fail unless ``self.prepared`` is truthy."""
     if not self.prepared:
         _raise(Exception('Dataset.prepare() must be called'))
Пример #18
0
Файл: eval.py Проект: xinDW/DVSR
def build_model_and_load_npz(epoch, use_cpu=False, save_pb=False):
    """Build the inference network and restore its weights from ``.npz`` checkpoints.

    Params:
        epoch : tag looked up in checkpoint file names; 0 is mapped to 'best'.
        use_cpu : place the graph on '/cpu:0' instead of '/gpu:<device_id>'.
        save_pb : currently unused.

    Returns:
        (sess, net, LR): the session, the output network and the input
        placeholder of shape ``[1] + lr_size``.

    Raises:
        Exception: if a required checkpoint file is not found, or ``archi2``
            is unknown in single-stage mode.
        ValueError: if ``archi1``/``archi2`` is unknown in two-stage mode.
    """
    epoch = 'best' if epoch == 0 else epoch

    def _search_for_ckpt_npz(file_dir, tags):
        """Return the first (name-sorted) ``.npz`` file in ``file_dir`` whose
        name contains every tag, or None when there is none."""
        # Sorting makes the pick deterministic; os.listdir order is arbitrary.
        # NOTE: tags are plain substrings, so e.g. epoch 10 also matches '100'.
        for filename in sorted(os.listdir(file_dir)):
            if '.npz' in filename and all(tag in filename for tag in tags):
                return filename
        return None

    if archi1 is not None:
        resolve_ckpt_file = _search_for_ckpt_npz(checkpoint_dir, ['resolve', str(epoch)])
        interp_ckpt_file = _search_for_ckpt_npz(checkpoint_dir, ['interp', str(epoch)])

        (resolve_ckpt_file is not None and interp_ckpt_file is not None) or _raise(Exception('checkpoint file not found'))
    else:
        ckpt_file = _search_for_ckpt_npz(checkpoint_dir, [str(epoch)])
        ckpt_file is not None or _raise(Exception('checkpoint file not found'))

    #======================================
    # build the model
    #======================================
    device_str = '/cpu:0' if use_cpu else '/gpu:%d' % device_id

    LR = tf.placeholder(tf.float32, [1] + lr_size)
    if archi1 is not None:
        # Two-stage model: archi1 resolves, archi2 (RDN) interpolates.
        with tf.device(device_str):
            if archi1 == 'dbpn':
                resolver = DBPN(LR, upscale=False, name="net_s1")
            elif archi1 == 'denoise':
                resolver = denoise_net(LR, name="net_s1")
            elif archi1 == 'unet':
                resolver = unet3d(LR, name="net_s1")
            else:
                _raise(ValueError('unknown stage-1 architecture: %s' % archi1))

            if archi2 == 'rdn':
                interpolator = res_dense_net(resolver.outputs, factor=factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=False, name="net_s2")
                net = interpolator
            else:
                _raise(ValueError('unknown stage-2 architecture: %s' % archi2))
    else:
        archi = archi2
        with tf.device(device_str):
            if archi == 'rdn':
                net = res_dense_net(LR, factor=factor, bn=using_batch_norm, conv_kernel=conv_kernel, name="net_s2")
            elif archi == 'unet':
                net = unet_care(LR)
            elif archi == 'dbpn':
                net = DBPN(LR, upscale=True)
            else:
                raise Exception('unknow architecture: %s' % archi)

    net.print_params(details=False)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    if archi1 is None:
        tl.files.load_and_assign_npz(sess=sess, name=os.path.join(checkpoint_dir, ckpt_file), network=net)
    else:
        tl.files.load_and_assign_npz(sess=sess, name=os.path.join(checkpoint_dir, resolve_ckpt_file), network=resolver)
        tl.files.load_and_assign_npz(sess=sess, name=os.path.join(checkpoint_dir, interp_ckpt_file), network=interpolator)

    return sess, net, LR
Пример #19
0
def create_patches_reduced_target(
        raw_data,
        patch_size,
        n_patches_per_image,
        reduction_axes,
        target_axes = None, # TODO: this should rather be part of RawData and also exposed to transforms
        **kwargs
    ):
    """Create normalized training data to be used for neural network training.

    In contrast to :func:`create_patches`, it is assumed that the target image has reduced
    dimensionality (i.e. size 1) along one or several axes (`reduction_axes`).

    Parameters
    ----------
    raw_data : :class:`RawData`
        See :func:`create_patches`.
    patch_size : tuple
        See :func:`create_patches`. Entries may be ``None`` for reduction axes,
        in which case the full image size along that axis is used.
    n_patches_per_image : int
        See :func:`create_patches`.
    reduction_axes : str
        Axes where the target images have a reduced dimension (i.e. size 1) compared to the source images.
    target_axes : str
        Axes of the raw target images. If ``None``, will be assumed to be equal to that of the raw source images.
    kwargs : dict
        Additional parameters as in :func:`create_patches`.

    Returns
    -------
    tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
        See :func:`create_patches`. Note that the shape of the target data will be 1 along all reduction axes.

    Raises
    ------
    ValueError
        If a ``None`` patch size is given for a non-reduction axis, or a
        reduction axis is missing from the extracted patches.

    """
    reduction_axes = axes_check_and_normalize(reduction_axes,disallowed='S')

    transforms = kwargs.get('transforms')
    if transforms is None:
        transforms = []
    transforms = list(transforms)
    # Broadcast the reduced target to the source shape so patches can be cut jointly.
    transforms.insert(0,broadcast_target(target_axes))
    kwargs['transforms'] = transforms

    save_file = kwargs.pop('save_file',None)

    if any(s is None for s in patch_size):
        # Peek at the first transformed image to resolve the `None` sizes.
        patch_axes = kwargs.get('patch_axes')
        if patch_axes is not None:
            _transforms = list(transforms)
            _transforms.append(permute_axes(patch_axes))
        else:
            _transforms = transforms
        composed = Transform(*zip(*_transforms))
        image_pairs = compose(*composed.generator)(raw_data.generator())
        x,y,axes,mask = next(image_pairs) # get the first entry from the generator
        patch_size = list(patch_size)
        for i,(a,s) in enumerate(zip(axes,patch_size)):
            if s is not None: continue
            a in reduction_axes or _raise(ValueError("entry of patch_size is None for non reduction axis %s." % a))
            patch_size[i] = x.shape[i]
        patch_size = tuple(patch_size)
        del x,y,axes,mask

    X,Y,axes = create_patches (
        raw_data            = raw_data,
        patch_size          = patch_size,
        n_patches_per_image = n_patches_per_image,
        **kwargs
    )

    ax = axes_dict(axes)
    for a in reduction_axes:
        # `a` is an axis letter, so it must be formatted with %s (not %d).
        a in axes or _raise(ValueError("reduction axis %s not present in extracted patches" % a))
        size_along_axis = Y.shape[ax[a]]
        if size_along_axis == 1:
            warnings.warn("extracted target patches already have dimensionality 1 along reduction axis %s." % a)
        else:
            t = np.take(Y,(1,),axis=ax[a])
            Y = np.take(Y,(0,),axis=ax[a])
            # Spot-check 100 random positions that the target is constant along `a`.
            sample_idx = np.random.choice(Y.size,size=100)
            if not np.all(t.flat[sample_idx]==Y.flat[sample_idx]):
                warnings.warn("extracted target patches vary along reduction axis %s." % a)

    if save_file is not None:
        print('Saving data to %s.' % str(Path(save_file)))
        save_training_data(save_file, X, Y, axes)

    return X,Y,axes
Пример #20
0
    def _generator(inputs):
        """Degrade each source image to simulate an anisotropic acquisition.

        For every ``(img, y, axes, mask)`` tuple the source ``img`` is
        optionally blurred with ``psf``, corrupted with Poisson and/or Gaussian
        noise, and down-/up-sampled by ``_subsample``. The target is yielded
        alongside; it is cropped like ``x`` when subsampling. ``psf``,
        ``psf_axes``, ``poisson_noise``, ``gauss_sigma``, ``_subsample``,
        ``crop_threshold`` and ``yield_target`` are closure variables of the
        enclosing function (not visible here).

        Yields:
            ``(x, target, axes, mask)`` with ``axes`` normalized.

        Raises:
            ValueError: on image/target shape mismatch, a negative psf, a
                psf/image ndim mismatch, or incompatible channel counts.
        """
        for img,y,axes,mask in inputs:

            if yield_target == 'source':
                y is None or np.allclose(img,y) or warnings.warn("ignoring 'target' image from input generator")
                target = img
            else:
                target = y

            img.shape == target.shape or _raise(ValueError(
                'source and target shapes differ: %s vs %s' % (img.shape, target.shape)))

            axes = axes_check_and_normalize(axes)
            _normalize_data = _make_normalize_data(axes)

            x = img.astype(np.float32, copy=False)

            if psf is not None:
                from scipy.signal import fftconvolve
                _psf = psf.astype(np.float32,copy=False)
                np.min(_psf) >= 0 or _raise(ValueError('psf has negative values.'))
                # Out-of-place: with copy=False `_psf` may alias the caller's
                # `psf`, which an in-place `/=` would silently modify.
                _psf = _psf / np.sum(_psf)
                if psf_axes is not None:
                    _psf = move_image_axes(_psf, psf_axes, axes, True)
                x.ndim == _psf.ndim or _raise(ValueError('image and psf must have the same number of dimensions.'))

                if 'C' in axes:
                    ch = axes_dict(axes)['C']
                    n_channels = x.shape[ch]
                    # convolve with psf separately for every channel
                    if _psf.shape[ch] == 1:
                        warnings.warn('applying same psf to every channel of the image.')
                    if _psf.shape[ch] in (1,n_channels):
                        x = np.stack([
                            fftconvolve(
                                np.take(x,   i,axis=ch),
                                np.take(_psf,i,axis=ch,mode='clip'),
                                mode='same'
                            )
                            for i in range(n_channels)
                        ],axis=ch)
                    else:
                        raise ValueError('number of psf channels (%d) incompatible with number of image channels (%d).' % (_psf.shape[ch],n_channels))
                else:
                    x = fftconvolve(x, _psf, mode='same')

            if bool(poisson_noise):
                # `np.int` was removed in NumPy 1.24; it was an alias of builtin int.
                x = np.random.poisson(np.maximum(0,x).astype(int)).astype(np.float32)

            if gauss_sigma > 0:
                noise = np.random.normal(0,gauss_sigma,size=x.shape).astype(np.float32)
                x = np.maximum(0,x+noise)

            if _subsample != 1:
                target = _normalize_data(target)
                x      = _normalize_data(x)

                subsample, subsample_size = _adjust_subsample(x.shape[0],_subsample,crop_threshold)
                if _subsample != subsample:
                    warnings.warn('changing subsample from %s to %s' % (str(_subsample),str(subsample)))

                target = _make_divisible_by_subsample(target,subsample_size)
                x      = _make_divisible_by_subsample(x,     subsample_size)
                x      = _scale_down_up(x,subsample)

                assert x.shape == target.shape, (x.shape, target.shape)

                target = _normalize_data(target,undo=True)
                x      = _normalize_data(x,     undo=True)

            yield x, target, axes, mask
Пример #21
0
def evaluate_whole(model,
                   half_precision_infer=False,
                   use_cpu=False,
                   large_volume=False,
                   save_pb=True,
                   save_activations=False):
    """Run a frozen two-stage (DBPN+RDN) graph on the bundled example data.

    Params:
        model : sample name; 'brain' selects the brain example (factor 4),
            any other value selects the tubulin/cell example (factor 2).
        half_precision_infer : load the half-precision ``.pb`` and predict in
            float16.
        use_cpu : hide all GPUs (sets CUDA_VISIBLE_DEVICES=-1 for the whole
            process) and load the CPU-tagged graph.
        large_volume : stream data through ``LargeDataPredictor``.
        save_pb, save_activations : currently unused.

    Raises:
        ValueError: if the expected ``.pb`` file does not exist.
    """
    if use_cpu:
        # Process-wide side effect: TensorFlow will not see any GPU afterwards.
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    device_tag = 'gpu' if not use_cpu else 'cpu'

    if model == 'brain':
        sample = 'brain'
        factor = 4
        lr_img_path = 'example_data/brain/LR/'
        save_dir = 'example_data/brain/SR/'
    else:
        sample = 'tubulin'
        factor = 2
        lr_img_path = 'example_data/cell/LR/'
        save_dir = 'example_data/cell/SR/'

    graph_file_tag = '%s_2stage_dbpn+rdn_factor%d_50x50x50_%s' % (
        sample, factor, device_tag)
    graph_file = '%s_half-precision.pb' % (
        graph_file_tag) if half_precision_infer else graph_file_tag + '.pb'

    model_path = os.path.join(pb_file_dir, graph_file)
    os.path.exists(model_path) or _raise(
        ValueError('%s doesn\'t exist' % model_path))

    import_name = "dsp"
    sess = load_graph(model_path, import_name=import_name, verbose=False)

    LR = sess.graph.get_tensor_by_name("%s/%s:0" %
                                       (import_name, input_op_name))
    net = sess.graph.get_tensor_by_name("%s/%s:0" %
                                        (import_name, output_op_name))

    exists_or_mkdir(save_dir)

    # `model` is the sample name; keep the wrapped network under its own name.
    sr_model = Model(net, sess, LR)
    block_size = lr_size[0:3]
    overlap = 0.2

    import imageio
    dtype = np.float16 if half_precision_infer else np.float32
    try:
        if large_volume:
            start_time = time.time()
            predictor = LargeDataPredictor(data_path=lr_img_path,
                                           saving_path=save_dir,
                                           factor=factor,
                                           model=sr_model,
                                           block_size=block_size,
                                           overlap=overlap,
                                           dtype=dtype)
            predictor.predict()
            print('time elapsed : %.2fs' % (time.time() - start_time))

        else:
            valid_lr_imgs = get_file_list(path=lr_img_path, regx='.*.tif')
            predictor = Predictor(factor=factor, model=sr_model, dtype=dtype)

            for im_file in valid_lr_imgs:
                start_time = time.time()
                im = imageio.volread(os.path.join(lr_img_path, im_file))

                print('=' * 66)
                print('predicting on %s ' % os.path.join(lr_img_path, im_file))
                sr = predictor.predict(im, block_size, overlap, low=0.2)
                print('time elapsed : %.4f' % (time.time() - start_time))
                imageio.volwrite(os.path.join(save_dir, 'DSP_' + im_file), sr)
    finally:
        # Always call recycle() (presumably releases the TF session), even
        # when a prediction or a write raises.
        sr_model.recycle()
Пример #22
0
 def predict(self, input):
     """Run the network on a single 3D volume.

     A leading batch axis and a trailing channel axis are added before
     ``input`` is fed to ``self.plchdr``. The size-1 first and last axes of
     the network output are squeezed away again.

     Params:
         input : array with exactly 3 dimensions (presumably depth, height,
             width; confirm against callers).

     Returns:
         the output of ``self.net`` with its first and last axes removed.

     Raises:
         ValueError: if ``input`` is not 3-dimensional.
     """
     len(input.shape) == 3 or _raise(ValueError(
         'input must be a 3D volume, got shape %s' % (input.shape,)))
     batch = input[np.newaxis, ..., np.newaxis]
     out = self.sess.run(self.net, {self.plchdr: batch})
     return np.squeeze(out, axis=(0, -1))