def _augment_tile(img_shape, pos, tile_shape, get_tile_func,
                  augment_params=None, **kwargs):
    '''
    Fetch a tile and augment it.

    If rotation or shear is activated, a 3 times larger tile is fetched
    (the requested tile extended by one tile size on every side) and the
    final tile is cut out from that after rotation/shear.

    :param img_shape: shape of the source image, forwarded to
                      inner_tile_size()
    :param pos: position of the requested tile, one entry per dimension
                (copied into a new array, the caller's object is not
                modified)
    :param tile_shape: shape of the requested tile, one entry per dimension
    :param get_tile_func: callable that delivers the pixel data; it is
                          called as
                          get_tile_func(pos=..., size=..., **kwargs)
    :param augment_params: optional dict, recognized keys are
                           'rotation_angle' (default 0),
                           'shear_angle' (default 0),
                           'rot90' (default 0),
                           'flipud' (default False),
                           'fliplr' (default False)
    :param kwargs: forwarded unchanged to get_tile_func
    :returns: the (optionally augmented) tile
    '''
    augment_params = augment_params or {}
    rotation_angle = augment_params.get('rotation_angle', 0)
    shear_angle = augment_params.get('shear_angle', 0)

    # np.array() copies, so the in-place arithmetic below never touches
    # the caller's pos/tile_shape objects.
    pos = np.array(pos)
    orig_tile_shape = np.array(tile_shape)
    tile_shape = np.array(tile_shape)

    # If the requested tile is only of size 1 in the last two dimensions,
    # augmentation can be omitted, since rotation and flipping always
    # occur around the center axis.
    augment_fast = (tile_shape[-2:] > 1).any()

    # Any non-zero angle activates the warp. Negative angles are valid
    # rotations/shears too and must not be skipped silently.
    augment_slow = augment_fast and (rotation_angle != 0 or shear_angle != 0)

    if augment_slow:
        # Enlarge the fetched region to 3x the tile size, keeping the
        # requested tile in the middle.
        pos -= tile_shape
        tile_shape *= 3

    res = inner_tile_size(img_shape, pos, tile_shape)
    pos_transient, size_transient, pos_inside_transient, pad_size = res

    tile = get_tile_func(pos=pos_transient, size=size_transient, **kwargs)
    # Mirror the fetched data by pad_size, then cut the (possibly enlarged)
    # tile out of the padded array.
    tile = np.pad(tile, pad_size, mode='symmetric')
    mesh = ut.get_tile_meshgrid(tile.shape, pos_inside_transient, tile_shape)
    tile = tile[tuple(mesh)]

    if augment_fast:
        rot90 = augment_params.get('rot90', 0)
        flipud = augment_params.get('flipud', False)
        fliplr = augment_params.get('fliplr', False)
        tile = trafo.flip_image_2d_stack(tile, fliplr=fliplr,
                                         flipud=flipud, rot90=rot90)

    if augment_slow:
        tile = trafo.warp_image_2d_stack(tile, rotation_angle, shear_angle)
        # Cut the originally requested tile out of the center of the
        # enlarged tile: offset == size == original tile shape.
        mesh = ut.get_tile_meshgrid(tile.shape, orig_tile_shape,
                                    orig_tile_shape)
        tile = tile[tuple(mesh)]

    return tile
def get_tile_func(pos=None, size=None, img=None):
    '''
    Cut a tile out of an in-memory image.

    :param pos: position of the tile, passed to get_tile_meshgrid()
    :param size: size of the tile, passed to get_tile_meshgrid()
    :param img: array to cut the tile from; must provide .shape and
                support indexing with the mesh
    :returns: img indexed with the mesh built from img.shape, pos and size
    '''
    mesh = get_tile_meshgrid(img.shape, pos, size)
    # Index with a tuple (as _augment_tile does): a plain list of index
    # arrays is not interpreted as a multidimensional index by recent
    # NumPy versions. tuple() is a no-op if the mesh already is a tuple.
    return img[tuple(mesh)]