Code example #1
0
File: patchlib.py Project: GRSEB9S/pytools-lib
def patch_gen(vol, patch_size, stride=1, nargout=1):
    """
    Generator of patches from an n-d volume (lightly tested).

    Patches are taken on a regular grid of start positions beginning at the
    origin; the last dimension varies fastest (C order).

    Parameters:
        vol (numpy array): the n-d volume to be patched
        patch_size (sequence of int): the size of the patches, one entry per
            dimension of vol
        stride (int or sequence of int, optional): step between consecutive patch
            starts, either one value for every dimension or one per dimension.
            default: 1
        nargout (int, optional): 1 (default) yields only the patch; any other value
            yields a tuple (patch, patch_sub), where patch_sub is the list of slice
            objects that locate the patch in vol

    Yields:
        numpy array, or (numpy array, list of slice) depending on nargout

    TODO: use .grid() to get sub
    """
    from itertools import product

    # expand a scalar stride to one value per dimension
    if isinstance(stride, (int, np.integer)):
        stride = [stride] * len(patch_size)

    cropped_vol_size = np.array(vol.shape) - np.array(patch_size) + 1
    # every dimension needs at least one valid start position; a ">= 0" check
    # let a patch one voxel larger than the volume through and yielded nothing
    assert np.all(cropped_vol_size > 0), \
        "patch size needs to be smaller than volume size"

    # valid patch start positions along each dimension
    sub = [range(0, cvs, step) for cvs, step in zip(cropped_vol_size, stride)]

    # itertools.product enumerates the start positions in the same order as
    # flattening an 'ij'-indexed ndgrid of sub (last dimension fastest),
    # but lazily and without building the full grid arrays
    for start in product(*sub):
        patch_sub = [slice(s, s + g) for s, g in zip(start, patch_size)]
        # numpy requires a tuple (not a list) of slices for n-d indexing
        patch = vol[tuple(patch_sub)]
        if nargout == 1:
            yield patch
        else:
            yield (patch, patch_sub)
Code example #2
0
File: patchlib.py Project: koo616/CT_registration
def patch_gen(vol,
              patch_size,
              stride=1,
              nargout=1,
              rand=False,
              rand_seed=None):
    """
    Generator of patches from an n-d volume (lightly tested).

    Parameters:
        vol (numpy array): the n-d volume to be patched
        patch_size (numpy vector): the size of the patches, one entry per dimension of vol
        stride (int or numpy vector, optional): stride (separation) between patch
            starts in each dimension. default: 1
        nargout (int, optional): how much to yield:
            1 (default: the patch) or any other value (tuple with the patch and the
            list of slices locating that patch in vol)
        rand (logical, optional): whether to randomize patch order (default: False)
        rand_seed (number, optional): random seed if randomizing patch order

    Yields:
        numpy array, or (numpy array, list of slice) depending on nargout

    TODO: use .grid() to get sub
    """
    from itertools import product

    # some parameter checking
    if isinstance(stride, (int, np.integer)):
        stride = [stride] * len(patch_size)
    assert len(vol.shape) == len(patch_size), \
        "vol shape %s and patch size %s do not match dimensions" \
        % (pformat(vol.shape), pformat(patch_size))
    assert len(vol.shape) == len(stride), \
        "vol shape %s and patch stride %s do not match dimensions" \
        % (pformat(vol.shape), pformat(stride))

    cropped_vol_size = np.array(vol.shape) - np.array(patch_size) + 1
    # every dimension needs at least one valid start position; a ">= 0" check
    # let a patch one voxel larger than the volume through and yielded nothing
    assert np.all(cropped_vol_size > 0), \
        "patch size needs to be smaller than volume size"

    # valid patch start positions along each dimension
    sub = [list(range(0, cvs, step)) for cvs, step in zip(cropped_vol_size, stride)]

    # check the size
    gs = gridsize(vol.shape, patch_size, patch_stride=stride)
    assert [len(f) for f in sub] == list(gs), 'Patch gen side failure'

    # all start positions, last dimension fastest (same order as flattening
    # an 'ij'-indexed ndgrid of sub)
    starts = product(*sub)
    if rand:
        starts = list(starts)
        order = list(range(len(starts)))
        # NOTE(review): this reseeds the module-level ``random`` generator;
        # confirm that ``shuffle`` draws from that same generator, otherwise
        # rand_seed does not make the order reproducible
        if rand_seed is not None:
            random.seed(rand_seed)
        shuffle(order)
        starts = [starts[i] for i in order]

    for start in starts:
        patch_sub = [slice(s, s + g) for s, g in zip(start, patch_size)]
        # numpy requires a tuple (not a list) of slices for n-d indexing
        patch = vol[tuple(patch_sub)]
        if nargout == 1:
            yield patch
        else:
            yield (patch, patch_sub)
Code example #3
0
File: patchlib.py Project: GRSEB9S/pytools-lib
def grid(vol_size, patch_size, patch_stride=1, start_sub=0, nargout=1, grid_type='idx'):
    """
    grid of patch starting points for nd volume that fit into given volume size

    The index is in the given volume. If the volume gets cropped as part of the function and you
    want a linear indexing into the new volume size, use
    >> newidx = ind2ind(new_vol_size, vol_size, idx)
    new_vol_size can be passed by the current function, see below.

    Parameters:
        vol_size (numpy vector): the size of the input volume
        patch_size (numpy vector): the size of the patches
        patch_stride (int or numpy vector, optional): stride (separation) in each dimension.
            default: 1
        start_sub (int or numpy vector, optional): the volume location where patches start
            This essentially means that the volume will be cropped starting at that location.
            e.g. if start_sub is [2, 2], then only vol[2:, 2:] will be included.
            default: 0
        nargout (int, 1, 2 or 3): optionally output new (cropped) volume size and the grid size
            return the idx array only if nargout is 1, or (idx, new_vol_size) if nargout is 2,
            or (idx, new_vol_size, grid_size) for any other value
        grid_type ('idx' or 'sub', optional): how to describe the grid.
            'idx': the linear (C-order) index into vol_size of every patch start, as an
            array shaped like the grid of starts.
            'sub': the output of nd.ndgrid over the per-dimension start locations, i.e.
            one array of subscripts per dimension (presumably each shaped like the grid
            of starts -- confirm against nd.ndgrid).

    Returns:
        idx only if nargout is 1, or (idx, new_vol_size) if nargout is 2,
            or (idx, new_vol_size, grid_size) otherwise

    Raises:
        AssertionError: if grid_type is invalid, or if the computed start locations in
            some dimension are empty or do not end exactly at the cropped volume's edge

    See also:
        gridsize()

    Contact:
        {adalca,klbouman}@csail.mit.edu
    """

    # parameter checking
    assert grid_type in ('idx', 'sub')
    if not isinstance(vol_size, np.ndarray):
        vol_size = np.array(vol_size, 'int')
    if not isinstance(patch_size, np.ndarray):
        patch_size = np.array(patch_size, 'int')
    nb_dims = len(patch_size)   # number of dimensions
    if isinstance(patch_stride, (int, np.integer)):
        patch_stride = np.repeat(patch_stride, nb_dims).astype('int')
    if isinstance(start_sub, (int, np.integer)):
        start_sub = np.repeat(start_sub, nb_dims).astype('int')

    # get the grid data
    [grid_size, new_vol_size] = gridsize(vol_size, patch_size,
                                         patch_stride=patch_stride,
                                         start_sub=start_sub,
                                         nargout=2)

    # prepare the sample locations (patch starts) in each dimension
    xvec = ()
    for idx in range(nb_dims):
        volend = new_vol_size[idx] + start_sub[idx] - patch_size[idx] + 1
        locs = list(range(start_sub[idx], volend, patch_stride[idx]))
        xvec += (locs, )
        # sanity check, per dimension: the last patch must end exactly on the
        # last voxel of the cropped volume
        assert locs and \
            (locs[-1] + patch_size[idx] - 1) == (new_vol_size[idx] + start_sub[idx] - 1)

    # get the nd grid of start subscripts
    idx = nd.ndgrid(*xvec)
    if grid_type == 'idx':
        # linear C-order index of each start. Equivalent to indexing
        # np.arange(prod(vol_size)).reshape(vol_size) with the subscripts, but
        # without allocating an array the size of the whole volume, and with
        # the subscripts passed as a tuple (numpy no longer treats a list of
        # arrays as a multidimensional index)
        idx = np.ravel_multi_index(tuple(idx), tuple(int(d) for d in vol_size))

    if nargout == 1:
        return idx
    elif nargout == 2:
        return (idx, new_vol_size)
    else:
        return (idx, new_vol_size, grid_size)
Code example #4
0
File: patchlib.py Project: MIDA-group/MultiRegEval
def patch_gen(vol,
              patch_size,
              stride=1,
              nargout=1,
              rand=False,
              rand_seed=None):
    """
    generator of patches from volume

    Parameters:
        vol (numpy array): the n-d volume to be patched
        patch_size (numpy vector): the size of the patches, one entry per dimension of vol
        stride (int or numpy vector, optional): stride (separation) between patch
            starts in each dimension. default: 1
        nargout (int, optional): how much to yield:
            1 (default: the patch) or any other value (tuple with the patch and the
            list of slices locating that patch in vol)
        rand (logical, optional): whether to randomize patch order (default: False)
        rand_seed (number, optional): random seed if randomizing patch order

    Yields:
        numpy array, or (numpy array, list of slice) depending on nargout

    TODO: test more...
    TODO: use .grid() to get sub
    """
    from itertools import product

    # some parameter checking
    if isinstance(stride, (int, np.integer)):
        stride = [stride] * len(patch_size)
    assert len(vol.shape) == len(patch_size), \
        "vol shape %s and patch size %s do not match dimensions" \
        % (pformat(vol.shape), pformat(patch_size))
    assert len(vol.shape) == len(stride), \
        "vol shape %s and patch stride %s do not match dimensions" \
        % (pformat(vol.shape), pformat(stride))

    cropped_vol_size = np.array(vol.shape) - np.array(patch_size) + 1
    # every dimension needs at least one valid start position; a ">= 0" check
    # let a patch one voxel larger than the volume through and yielded nothing
    assert np.all(cropped_vol_size > 0), \
        "patch size needs to be smaller than volume size"

    # valid patch start positions along each dimension
    sub = [list(range(0, cvs, step)) for cvs, step in zip(cropped_vol_size, stride)]

    # check the size
    gs = gridsize(vol.shape, patch_size, patch_stride=stride)
    assert [len(f) for f in sub] == list(gs), 'Patch gen side failure'

    # all start positions, last dimension fastest (same order as flattening
    # an 'ij'-indexed ndgrid of sub)
    starts = product(*sub)
    if rand:
        starts = list(starts)
        order = list(range(len(starts)))
        # NOTE(review): this reseeds the module-level ``random`` generator;
        # confirm that ``shuffle`` draws from that same generator, otherwise
        # rand_seed does not make the order reproducible
        if rand_seed is not None:
            random.seed(rand_seed)
        shuffle(order)
        starts = [starts[i] for i in order]

    for start in starts:
        patch_sub = [slice(s, s + g) for s, g in zip(start, patch_size)]
        # numpy requires a tuple (not a list) of slices for n-d indexing
        patch = vol[tuple(patch_sub)]
        if nargout == 1:
            yield patch
        else:
            yield (patch, patch_sub)