Example no. 1
0
def _augment_tile(img_shape,
                  pos,
                  tile_shape,
                  get_tile_func,
                  augment_params=None,
                  **kwargs):
    '''
    Fetch a tile and augment it.

    Flips / rot90 are applied directly to the fetched tile. If a rotation
    or shear angle is set (non-zero, positive or negative), a tile three
    times larger in every dimension is fetched instead, warped, and the
    final tile is cut out of its center after rotation/shear.

    Parameters
    ----------
    img_shape : sequence of int
        Shape of the source image; forwarded to ``inner_tile_size``.
    pos : sequence of int
        Position of the requested tile (presumably its origin corner --
        confirm against ``inner_tile_size``).
    tile_shape : sequence of int
        Size of the requested tile. The last two axes are the plane in
        which flips, rot90, rotation and shear are applied.
    get_tile_func : callable
        Called as ``get_tile_func(pos=..., size=..., **kwargs)``; must
        return an array for the part of the tile inside the image.
    augment_params : dict, optional
        Recognised keys (with defaults): 'rotation_angle' (0),
        'shear_angle' (0), 'rot90' (0), 'flipud' (False), 'fliplr' (False).
    **kwargs
        Forwarded unchanged to ``get_tile_func``.

    Returns
    -------
    numpy.ndarray
        The augmented tile (presumably of shape ``tile_shape`` -- assumes
        ``ut.get_tile_meshgrid`` cuts exactly the requested size).
    '''
    augment_params = augment_params or {}
    rotation_angle = augment_params.get('rotation_angle', 0)
    shear_angle = augment_params.get('shear_angle', 0)

    pos = np.array(pos)
    orig_tile_shape = np.array(tile_shape)
    tile_shape = np.array(tile_shape)

    # Augmentations act in the last two axes; a tile of size 1 in both of
    # them is unchanged by flips/rotations around its center, so skip them.
    augment_fast = (tile_shape[-2:] > 1).any()
    # Negative angles are valid transformations as well, so test for
    # non-zero rather than positive values.
    augment_slow = augment_fast and (rotation_angle != 0 or shear_angle != 0)

    if augment_slow:
        # Enlarge the fetched region by one tile size on every side so the
        # center tile can be cut out of real surrounding content after warp.
        pos -= tile_shape
        tile_shape *= 3

    res = inner_tile_size(img_shape, pos, tile_shape)
    pos_transient, size_transient, pos_inside_transient, pad_size = res

    tile = get_tile_func(pos=pos_transient, size=size_transient, **kwargs)

    # pad_size presumably covers the part of the requested region lying
    # outside the image; fill it by mirroring the fetched data.
    tile = np.pad(tile, pad_size, mode='symmetric')
    mesh = ut.get_tile_meshgrid(tile.shape, pos_inside_transient, tile_shape)
    tile = tile[tuple(mesh)]

    if augment_fast:
        rot90 = augment_params.get('rot90', 0)
        flipud = augment_params.get('flipud', False)
        fliplr = augment_params.get('fliplr', False)

        tile = trafo.flip_image_2d_stack(tile, fliplr=fliplr,
                                         flipud=flipud, rot90=rot90)

    if augment_slow:
        tile = trafo.warp_image_2d_stack(tile, rotation_angle, shear_angle)
        # The original tile sits at offset orig_tile_shape inside the 3x
        # enlarged tile (pos was shifted back by one tile size above).
        mesh = ut.get_tile_meshgrid(tile.shape, orig_tile_shape,
                                    orig_tile_shape)
        tile = tile[tuple(mesh)]

    return tile
Example no. 2
0
 def get_tile_func(pos=None, size=None, img=None):
     '''
     Cut the tile of size ``size`` at position ``pos`` out of ``img``.

     Uses the keyword signature ``get_tile_func(pos=..., size=..., **kwargs)``
     expected by ``_augment_tile``, with the source array given as ``img``.
     '''
     # get_tile_meshgrid presumably returns a list of index arrays; NumPy
     # needs a tuple for multidimensional indexing (since NumPy 1.23 a list
     # is treated as a single fancy index), so convert explicitly.
     return img[tuple(get_tile_meshgrid(img.shape, pos, size))]