def patch_gen(vol, patch_size, stride=1, nargout=1):
    """
    Generate patches from an n-d volume on a regular grid.

    NOT VERY WELL TESTED.

    Patch start positions begin at 0 in every dimension and advance by
    `stride`; only patches that lie fully inside `vol` are produced. Patches
    are yielded in C (row-major) order of their start positions, i.e. the last
    dimension varies fastest.

    Parameters:
        vol (numpy array): the n-d volume to be patched
        patch_size (sequence of int): the size of the patches
        stride (int or sequence of int, optional): separation between patch
            starts, either one value for all dimensions or one per dimension.
            default: 1
        nargout (int, optional): 1 (default) yields only the patch; any other
            value yields a tuple (patch, patch_sub), where patch_sub is the
            list of slices that extracts the patch from vol.

    TODO: use .grid() to get sub
    """
    import itertools  # local import keeps this block self-contained

    cropped_vol_size = np.array(vol.shape) - np.array(patch_size) + 1
    # every dimension needs at least one valid start position (patch <= vol)
    assert np.all(cropped_vol_size > 0), \
        "patch size needs to be smaller than volume size"

    # accept a single (python or numpy) integer stride or one per dimension
    if isinstance(stride, (int, np.integer)):
        stride = [stride] * len(cropped_vol_size)
    assert len(stride) == len(cropped_vol_size), \
        "patch stride must have one entry per volume dimension"

    # start positions along each dimension
    sub = [range(0, cvs, st) for cvs, st in zip(cropped_vol_size, stride)]

    # itertools.product varies the last dimension fastest, which matches the
    # C-order flattening of an 'ij'-indexed ndgrid of the same starts
    for starts in itertools.product(*sub):
        patch_sub = [slice(s, s + p) for s, p in zip(starts, patch_size)]
        # numpy requires a tuple (not a list) of slices for multi-dim indexing
        patch = vol[tuple(patch_sub)]
        if nargout == 1:
            yield patch
        else:
            yield (patch, patch_sub)
def patch_gen(vol, patch_size, stride=1, nargout=1, rand=False, rand_seed=None):
    """
    Generate patches from an n-d volume on a regular grid.

    NOT VERY WELL TESTED.

    Patch starts begin at 0 in every dimension and advance by `stride`; only
    patches fully inside `vol` are produced. With rand=True the patch order is
    shuffled (optionally seeded with rand_seed).

    nargout == 1 yields the patch; any other value yields (patch, patch_sub)
    where patch_sub is the list of slices extracting the patch from vol.

    TODO: use .grid() to get sub
    """
    # some parameter checking; accept python or numpy integer strides
    if isinstance(stride, (int, np.integer)):
        stride = [stride for _ in patch_size]
    assert len(vol.shape) == len(patch_size), \
        "vol shape %s and patch size %s do not match dimensions" \
        % (pformat(vol.shape), pformat(patch_size))
    assert len(vol.shape) == len(stride), \
        "vol shape %s and patch stride %s do not match dimensions" \
        % (pformat(vol.shape), pformat(stride))

    cropped_vol_size = np.array(vol.shape) - np.array(patch_size) + 1
    # every dimension needs at least one valid start position (patch <= vol)
    assert np.all(cropped_vol_size > 0), \
        "patch size needs to be smaller than volume size"

    # start positions along each dimension
    sub = tuple(list(range(0, cvs, st)) for cvs, st in zip(cropped_vol_size, stride))

    # check the size against the library's grid computation
    gs = gridsize(vol.shape, patch_size, patch_stride=stride)
    assert [len(f) for f in sub] == list(gs), 'Patch gen side failure'

    # get ndgrid of subs, flattened so entry k holds the k-th patch start per dim
    ndg = nd.ndgrid(*sub)
    ndg = [f.flat for f in ndg]

    # generator
    rng = list(range(len(ndg[0])))
    if rand:
        # NOTE(review): reseeds the global `random` state; assumes `shuffle`
        # draws from that state (e.g. random.shuffle) -- confirm at import site
        if rand_seed is not None:
            random.seed(rand_seed)
        shuffle(rng)

    for idx in rng:
        patch_sub = [slice(f[idx], f[idx] + g) for f, g in zip(ndg, patch_size)]
        # numpy requires a tuple (not a list) of slices for multi-dim indexing
        patch = vol[tuple(patch_sub)]
        if nargout == 1:
            yield patch
        else:
            yield (patch, patch_sub)
def grid(vol_size, patch_size, patch_stride=1, start_sub=0, nargout=1, grid_type='idx'):
    """
    grid of patch starting points for nd volume that fit into given volume size

    The index is in the given volume. If the volume gets cropped as part of the function and
    you want a linear indexing into the new volume size, use
        >> newidx = ind2ind(new_vol_size, vol_size, idx)
    new_vol_size can be passed by the current function, see below.

    Parameters:
        vol_size (numpy vector): the size of the input volume
        patch_size (numpy vector): the size of the patches
        patch_stride (int or numpy vector, optional): stride (separation) in each dimension.
            default: 1
        start_sub (int or numpy vector, optional): the volume location where patches start.
            This essentially means that the volume will be cropped starting at that location,
            e.g. if start_sub is [2, 2], then only vol[2:, 2:] will be included.
            default: 0
        nargout (int, 1, 2 or 3): optionally output new (cropped) volume size and the grid size
            return the idx array only if nargout is 1, or (idx, new_vol_size) if nargout is 2,
            or (idx, new_vol_size, grid_size) otherwise
        grid_type ('idx' or 'sub', optional): how to describe the grid.
            'idx': an ndarray of C-order linear indices into a volume of size vol_size,
                one entry per patch start, laid out like the grid.
            'sub': the output of nd.ndgrid over the per-dimension start locations,
                presumably one array per dimension holding that dimension's start subscript.
                NOTE(review): older docs described 'sub' as an nb_patches x nb_dims ndarray;
                this implementation does not reshape to that form.

    Returns:
        idx only if nargout is 1, or (idx, new_vol_size) if nargout is 2,
        or (idx, new_vol_size, grid_size) otherwise

    See also:
        gridsize()

    Contact:
        {adalca,klbouman}@csail.mit.edu
    """
    # parameter checking
    assert grid_type in ('idx', 'sub')
    if not isinstance(vol_size, np.ndarray):
        vol_size = np.array(vol_size, 'int')
    if not isinstance(patch_size, np.ndarray):
        patch_size = np.array(patch_size, 'int')
    nb_dims = len(patch_size)  # number of dimensions

    # broadcast scalar (python or numpy integer) settings to every dimension
    if isinstance(patch_stride, (int, np.integer)):
        patch_stride = np.repeat(patch_stride, nb_dims).astype('int')
    if isinstance(start_sub, (int, np.integer)):
        start_sub = np.repeat(start_sub, nb_dims).astype('int')

    # get the grid data
    [grid_size, new_vol_size] = gridsize(vol_size, patch_size, patch_stride=patch_stride,
                                         start_sub=start_sub, nargout=2)

    # prepare the sample grid (patch start locations) in each dimension
    xvec = ()
    for dim in range(nb_dims):
        volend = new_vol_size[dim] + start_sub[dim] - patch_size[dim] + 1
        locs = list(range(start_sub[dim], volend, patch_stride[dim]))
        xvec += (locs, )
        # the last patch along this dimension must end exactly at the end of the cropped volume
        assert (locs[-1] + patch_size[dim] - 1) == (new_vol_size[dim] + start_sub[dim] - 1)

    # get the nd grid of start subscripts
    idx = nd.ndgrid(*xvec)

    if grid_type == 'idx':
        # C-order linear index into a volume of size vol_size; equivalent to indexing
        # np.arange(prod(vol_size)).reshape(vol_size), without allocating that volume.
        # The subscripts must be passed as a tuple (one array per dimension).
        idx = np.ravel_multi_index(tuple(idx), tuple(int(d) for d in vol_size))

    if nargout == 1:
        return idx
    elif nargout == 2:
        return (idx, new_vol_size)
    else:
        return (idx, new_vol_size, grid_size)
def patch_gen(vol, patch_size, stride=1, nargout=1, rand=False, rand_seed=None):
    """
    generator of patches from volume

    Patch starts begin at 0 in every dimension and advance by `stride`; only
    patches that lie fully inside `vol` are produced.

    Parameters:
        vol (numpy array): the n-d volume to be patched
        patch_size (numpy vector): the size of the patches
        stride (int or numpy vector, optional): stride (separation) in each dimension.
            default: 1
        nargout (int, optional): how much to yield
            1 (default: the patch) or 2 (tuple with the patch and volume slices for that patch)
        rand (logical, optional): whether to randomize patch order (default: False)
        rand_seed (number, optional): random seed if randomizing patch order

    TODO: test more...
    TODO: use .grid() to get sub
    """
    # some parameter checking; accept python or numpy integer strides
    if isinstance(stride, (int, np.integer)):
        stride = [stride for _ in patch_size]
    assert len(vol.shape) == len(patch_size), \
        "vol shape %s and patch size %s do not match dimensions" \
        % (pformat(vol.shape), pformat(patch_size))
    assert len(vol.shape) == len(stride), \
        "vol shape %s and patch stride %s do not match dimensions" \
        % (pformat(vol.shape), pformat(stride))

    cropped_vol_size = np.array(vol.shape) - np.array(patch_size) + 1
    # every dimension needs at least one valid start position (patch <= vol)
    assert np.all(cropped_vol_size > 0), \
        "patch size needs to be smaller than volume size"

    # start positions along each dimension
    sub = tuple(list(range(0, cvs, st)) for cvs, st in zip(cropped_vol_size, stride))

    # check the size against the library's grid computation
    gs = gridsize(vol.shape, patch_size, patch_stride=stride)
    assert [len(f) for f in sub] == list(gs), 'Patch gen side failure'

    # get ndgrid of subs, flattened so entry k holds the k-th patch start per dim
    ndg = nd.ndgrid(*sub)
    ndg = [f.flat for f in ndg]

    # generator
    rng = list(range(len(ndg[0])))
    if rand:
        # NOTE(review): reseeds the global `random` state; assumes `shuffle`
        # draws from that state (e.g. random.shuffle) -- confirm at import site
        if rand_seed is not None:
            random.seed(rand_seed)
        shuffle(rng)

    for idx in rng:
        patch_sub = [slice(f[idx], f[idx] + g) for f, g in zip(ndg, patch_size)]
        # numpy requires a tuple (not a list) of slices for multi-dim indexing
        patch = vol[tuple(patch_sub)]
        if nargout == 1:
            yield patch
        else:
            yield (patch, patch_sub)