def augment_spatial_peaks(data,
                          seg,
                          patch_size,
                          patch_center_dist_from_border=30,
                          do_elastic_deform=True,
                          alpha=(0., 1000.),
                          sigma=(10., 13.),
                          do_rotation=True,
                          angle_x=(0, 2 * np.pi),
                          angle_y=(0, 2 * np.pi),
                          angle_z=(0, 2 * np.pi),
                          do_scale=True,
                          scale=(0.75, 1.25),
                          border_mode_data='nearest',
                          border_cval_data=0,
                          order_data=3,
                          border_mode_seg='constant',
                          border_cval_seg=0,
                          order_seg=0,
                          random_crop=True,
                          p_el_per_sample=1,
                          p_scale_per_sample=1,
                          p_rot_per_sample=1,
                          slice_dir=None):
    """Spatially augment a batch of 2D slices and rotate their vector channels to match.

    Works like ``augment_spatial`` (elastic deformation, in-plane rotation,
    scaling and cropping of the voxel grid). In addition, the content of each
    sample is passed through ``rotate_multiple_peaks`` (9 channels) or
    ``rotate_multiple_tensors`` (18 channels) with the sampled in-plane angle.
    This keeps the vectors aligned with the rotated image.

    :param data: array of shape ``(batch, channels, *spatial)`` with 9 or 18
        channels.
    :param seg: array of shape ``(batch, seg_channels, *spatial)`` or ``None``.
    :param patch_size: spatial output size; must have exactly two entries.
    :param patch_center_dist_from_border: scalar or per-axis minimum distance of
        the random patch center from the image border (used when ``random_crop``).
    :param do_elastic_deform: enable elastic deformation; its parameters are
        drawn uniformly from ``alpha`` and ``sigma``.
    :param do_rotation: enable in-plane rotation with an angle from ``angle_x``.
    :param angle_y: unused (only the in-plane angle ``angle_x`` is sampled in 2D).
    :param angle_z: unused (see ``angle_y``).
    :param do_scale: enable scaling with a factor drawn from ``scale``.
    :param border_mode_data, border_cval_data, order_data: interpolation
        settings for ``data``.
    :param border_mode_seg, border_cval_seg, order_seg: interpolation settings
        for ``seg``.
    :param random_crop: pick a random patch center instead of the image center.
    :param p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
        probabilities of applying the respective transform.
    :param slice_dir: 0, 1 or 2; selects which of the three rotation angles
        passed to the vector rotation receives the sampled in-plane angle.
    :return: ``(data_result, seg_result)`` as float32 arrays of shape
        ``(batch, channels, *patch_size)``; ``seg_result`` is ``None`` when
        ``seg`` is ``None``.
    :raises ValueError: if ``patch_size`` is not 2D, ``slice_dir`` is not
        0/1/2, or ``data`` does not have 9 or 18 channels.
    """
    dim = len(patch_size)

    # Validate the arguments before doing any work. Previously these errors
    # were only raised after the first sample had already been deformed and
    # interpolated.
    if dim != 2:
        raise ValueError(
            "augment_spatial_peaks only supports 2D at the moment")
    if slice_dir not in (0, 1, 2):
        raise ValueError("invalid slice_dir passed as argument")
    n_channels = data.shape[1]
    if n_channels not in (9, 18):
        raise ValueError("Incorrect number of channels (expected 9 or 18)")
    use_peaks = n_channels == 9  # otherwise 18 channels -> tensors

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size),
                              dtype=np.float32)
    data_result = np.zeros((data.shape[0], n_channels, *patch_size),
                           dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        # In-plane rotation angle. It stays 0 when no rotation is applied, so
        # the vector rotation below then uses all-zero angles.
        angle = 0
        if np.random.uniform() < p_rot_per_sample and do_rotation:
            if angle_x[0] == angle_x[1]:
                angle = angle_x[0]
            else:
                angle = np.random.uniform(angle_x[0], angle_x[1])
            coords = rotate_coords_2d(coords, angle)
            modified_coords = True

        if np.random.uniform() < p_scale_per_sample and do_scale:
            if np.random.random() < 0.5 and scale[0] < 1:
                sc = np.random.uniform(scale[0], 1)
            else:
                sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Shift the zero-centered grid to the chosen patch center.
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[d],
                        data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            for channel_id in range(n_channels):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
        else:
            # No spatial transform was applied: a plain crop is sufficient.
            seg_batch = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [
                    patch_center_dist_from_border[d] - patch_size[d] // 2
                    for d in range(dim)
                ]
                cropped_data, cropped_seg = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_batch, patch_size,
                    margin)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_batch)
            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]

        # Rotate the peaks / tensors with the same in-plane angle. slice_dir
        # selects which axis angle carries it.
        a_x = a_y = a_z = 0
        if slice_dir == 0:
            a_x = angle
        elif slice_dir == 1:
            a_y = angle
        else:
            # slice_dir == 2: the rotation direction has to be inverted for z
            # to align with the rotated voxels. The reason is unclear, maybe
            # different conventions for peaks and voxels.
            a_z = -angle

        if use_peaks:
            data_result[sample_id] = rotate_multiple_peaks(
                data_result[sample_id], a_x, a_y, a_z)
        else:
            data_result[sample_id] = rotate_multiple_tensors(
                data_result[sample_id], a_x, a_y, a_z)

    return data_result, seg_result
# Example #2
# Score: 0
def augment_spatial(data,
                    seg,
                    patch_size,
                    patch_center_dist_from_border=30,
                    do_elastic_deform=True,
                    alpha=(0., 1000.),
                    sigma=(10., 13.),
                    do_rotation=True,
                    angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi),
                    angle_z=(0, 2 * np.pi),
                    do_scale=True,
                    scale=(0.75, 1.25),
                    border_mode_data='nearest',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=1,
                    p_scale_per_sample=1,
                    p_rot_per_sample=1,
                    independent_scale_for_each_axis=False,
                    p_rot_per_axis: float = 1,
                    p_independent_scale_per_axis: float = 1):
    """Apply random elastic deformation, rotation, scaling and cropping to a batch.

    Each sample gets its own random transform. If any spatial transform was
    applied, ``data`` (and ``seg``) are resampled on the transformed coordinate
    grid with ``interpolate_img``. Otherwise the patch is simply cropped,
    randomly or centrally.

    :param data: array of shape ``(batch, channels, *spatial)``.
    :param seg: array of shape ``(batch, seg_channels, *spatial)`` or ``None``.
    :param patch_size: spatial size of the output patch (2 or 3 entries).
    :param patch_center_dist_from_border: scalar or per-axis minimum distance of
        the random patch center from the image border (used when ``random_crop``).
    :param do_elastic_deform: enable elastic deformation; its parameters are
        drawn uniformly from ``alpha`` and ``sigma``.
    :param do_rotation: enable rotation with angles from ``angle_x`` (plus
        ``angle_y``/``angle_z`` in 3D).
    :param do_scale: enable scaling with a factor from ``scale``. When the
        range allows it, factors below and above 1 are equally likely.
    :param border_mode_data, border_cval_data, order_data: interpolation
        settings for ``data``.
    :param border_mode_seg, border_cval_seg, order_seg: interpolation settings
        for ``seg``.
    :param random_crop: pick a random patch center instead of the image center.
    :param p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
        probabilities of applying the respective transform.
    :param independent_scale_for_each_axis: allow drawing a separate scale
        factor per axis, with probability ``p_independent_scale_per_axis``.
    :param p_rot_per_axis: probability of rotating about each individual axis.
    :return: ``(data_result, seg_result)`` as float32 arrays of shape
        ``(batch, channels, *patch_size)``; ``seg_result`` is ``None`` when
        ``seg`` is ``None``.
    """
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size),
                              dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1], *patch_size),
                           dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def sample_angle(angle_range):
        # Rotate about this axis with probability p_rot_per_axis. Use ``<``
        # (not ``<=``) because np.random.uniform() can return 0.0; this way
        # p_rot_per_axis=0 truly disables the rotation, like the other
        # probability checks.
        if np.random.uniform() < p_rot_per_axis:
            return np.random.uniform(angle_range[0], angle_range[1])
        return 0

    def sample_scale():
        # Shrink or enlarge with equal probability if the range allows both.
        if np.random.random() < 0.5 and scale[0] < 1:
            return np.random.uniform(scale[0], 1)
        return np.random.uniform(max(scale[0], 1), scale[1])

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            a_x = sample_angle(angle_x)
            if dim == 3:
                a_y = sample_angle(angle_y)
                a_z = sample_angle(angle_z)
                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if (independent_scale_for_each_axis
                    and np.random.uniform() < p_independent_scale_per_axis):
                sc = [sample_scale() for _ in range(dim)]
            else:
                sc = sample_scale()
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Shift the zero-centered grid to the chosen patch center.
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[d],
                        data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
        else:
            # No spatial transform was applied: a plain crop is sufficient.
            seg_batch = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [
                    patch_center_dist_from_border[d] - patch_size[d] // 2
                    for d in range(dim)
                ]
                cropped_data, cropped_seg = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_batch, patch_size,
                    margin)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_batch)
            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]
    return data_result, seg_result
# Example #3
# Score: 0
def augment_spatial_2(data,
                      seg,
                      patch_size,
                      patch_center_dist_from_border=30,
                      do_elastic_deform=True,
                      deformation_scale=(0, 0.25),
                      do_rotation=True,
                      angle_x=(0, 2 * np.pi),
                      angle_y=(0, 2 * np.pi),
                      angle_z=(0, 2 * np.pi),
                      do_scale=True,
                      scale=(0.75, 1.25),
                      border_mode_data='nearest',
                      border_cval_data=0,
                      order_data=3,
                      border_mode_seg='constant',
                      border_cval_seg=0,
                      order_seg=0,
                      random_crop=True,
                      p_el_per_sample=1,
                      p_scale_per_sample=1,
                      p_rot_per_sample=1,
                      independent_scale_for_each_axis=False,
                      p_rot_per_axis: float = 1,
                      p_independent_scale_per_axis: float = 1):
    """Spatially augment a batch using a patch-size-relative elastic deformation.

    Variant of ``augment_spatial``. The elastic deformation is parameterised
    relative to ``patch_size`` (via ``elastic_deform_coordinates_2``), and the
    transformed grid is re-centered on its mean before the patch center is
    added.

    :param data: array of shape ``(batch, channels, *spatial)``.
    :param seg: array of shape ``(batch, seg_channels, *spatial)`` or ``None``.
    :param patch_size: spatial size of the output patch (2 or 3 entries).
    :param patch_center_dist_from_border: scalar or per-axis minimum distance of
        the random patch center from the image border (used when ``random_crop``).
    :param do_elastic_deform: enable elastic deformation.
    :param deformation_scale: ``(low, high)`` range for the per-sample
        deformation scale, as a fraction of ``patch_size`` (0.125 = 12.5% of
        the patch size in each dimension). Per axis, the deformation sigma is
        ``scale * patch_size[axis]``. The magnitude is drawn uniformly from
        ``[sigma / 8, sigma / 2]``.
    :param do_rotation: enable rotation with angles from ``angle_x`` (plus
        ``angle_y``/``angle_z`` in 3D).
    :param do_scale: enable scaling with a factor from ``scale``. When the
        range allows it, factors below and above 1 are equally likely.
    :param border_mode_data, border_cval_data, order_data: interpolation
        settings for ``data``.
    :param border_mode_seg, border_cval_seg, order_seg: interpolation settings
        for ``seg``.
    :param random_crop: pick a random patch center instead of the image center.
    :param p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
        probabilities of applying the respective transform.
    :param independent_scale_for_each_axis: allow drawing a separate scale
        factor per axis, with probability ``p_independent_scale_per_axis``.
    :param p_rot_per_axis: probability of rotating about each individual axis.
    :return: ``(data_result, seg_result)`` as float32 arrays of shape
        ``(batch, channels, *patch_size)``; ``seg_result`` is ``None`` when
        ``seg`` is ``None``.
    """
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size),
                              dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1], *patch_size),
                           dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def sample_angle(angle_range):
        # Rotate about this axis with probability p_rot_per_axis. Use ``<``
        # (not ``<=``) because np.random.uniform() can return 0.0; this way
        # p_rot_per_axis=0 truly disables the rotation, like the other
        # probability checks.
        if np.random.uniform() < p_rot_per_axis:
            return np.random.uniform(angle_range[0], angle_range[1])
        return 0

    def sample_scale():
        # Shrink or enlarge with equal probability if the range allows both.
        if np.random.random() < 0.5 and scale[0] < 1:
            return np.random.uniform(scale[0], 1)
        return np.random.uniform(max(scale[0], 1), scale[1])

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            # One deformation scale per sample, as a fraction of patch_size.
            def_scale = np.random.uniform(deformation_scale[0],
                                          deformation_scale[1])
            sigmas = []
            magnitudes = []
            for d in range(dim):
                # Convert the relative scale into voxels for this axis.
                sigmas.append(def_scale * patch_size[d])
                # The magnitude has to grow with sigma, otherwise little
                # happens most of the time. Above sigma / 2 the deformations
                # become very ugly, so sample in [sigma / 8, sigma / 2].
                magnitudes.append(np.random.uniform(sigmas[-1] * (1 / 8),
                                                    sigmas[-1] * (1 / 2)))
            coords = elastic_deform_coordinates_2(coords, sigmas, magnitudes)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            a_x = sample_angle(angle_x)
            if dim == 3:
                a_y = sample_angle(angle_y)
                a_z = sample_angle(angle_z)
                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if (independent_scale_for_each_axis
                    and np.random.uniform() < p_independent_scale_per_axis):
                sc = [sample_scale() for _ in range(dim)]
            else:
                sc = sample_scale()
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Make the transformed grid zero-centered again (per axis) before
            # the patch center is added.
            coords -= coords.mean(axis=tuple(range(1, len(coords.shape))),
                                  keepdims=True)
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[d],
                        data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
        else:
            # No spatial transform was applied: a plain crop is sufficient.
            seg_batch = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [
                    patch_center_dist_from_border[d] - patch_size[d] // 2
                    for d in range(dim)
                ]
                cropped_data, cropped_seg = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_batch, patch_size,
                    margin)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_batch)
            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]
    return data_result, seg_result
# Example #4
# Score: 0
def augment_spatial_transform(data, seg=None, patch_size=None, patch_center_dist_from_border=None,
                              do_elastic_deform=True, alpha=(0., 900.), sigma=(9., 13.), do_rotation=True,
                              angle_x=default_3D_augmentation_params.get("rotation_x"),
                              angle_y=default_3D_augmentation_params.get("rotation_y"),
                              angle_z=default_3D_augmentation_params.get("rotation_z"),
                              do_scale=True, scale=(0.85, 1.25),
                              border_mode_data='constant', border_cval_data=0, order_data=3,
                              border_mode_seg="constant", border_cval_seg=-1, order_seg=1,
                              random_crop=False, p_el_per_sample=1,
                              p_scale_per_sample=1, p_rot_per_sample=1):
    """Apply one random elastic/rotation/scale transform plus crop to a single sample.

    Unlike ``augment_spatial``, this works on one un-batched sample. Every
    channel of ``data`` (and ``seg``) is resampled with the same transform.
    If no spatial transform was drawn, the patch is only cropped.

    :param data: array of shape ``(C, *spatial)``. A 3-dimensional array is
        treated as lacking the channel axis and gets a leading one.
    :param seg: optional array of shape ``(C_seg, *spatial)``. If it has one
        dimension fewer than ``data`` (after the expansion above), it is given a
        leading channel axis.
    :param patch_size: spatial output size; defaults to the full spatial
        extent of ``data``.
    :param patch_center_dist_from_border: scalar or per-axis minimum distance of
        the random patch center from the border; required when ``random_crop``.
    :param do_elastic_deform: enable elastic deformation; its parameters are
        drawn uniformly from ``alpha`` and ``sigma``.
    :param do_rotation: enable rotation with angles from ``angle_x`` (plus
        ``angle_y``/``angle_z`` in 3D); a degenerate range ``(a, a)`` yields
        exactly ``a``.
    :param do_scale: enable scaling with a factor from ``scale``.
    :param border_mode_data, border_cval_data, order_data: interpolation
        settings for ``data``.
    :param border_mode_seg, border_cval_seg, order_seg: interpolation settings
        for ``seg``.
    :param random_crop: pick a random patch center instead of the image center.
    :param p_el_per_sample, p_scale_per_sample, p_rot_per_sample: probabilities
        of applying the respective transform.
    :return: ``(data_result, seg_result)``; ``seg_result`` is ``None`` when
        ``seg`` is ``None``.
    :raises ValueError: if ``random_crop`` is set but
        ``patch_center_dist_from_border`` is ``None``.
    """
    if len(data.shape) == 3:
        print("Attention: data shape must be (C, D, H, W). I will transform it.")
        data = data[None]
    # Give a channel-less segmentation its channel axis regardless of whether
    # ``data`` needed one. Previously this only happened when ``data`` was 3D.
    if seg is not None and len(seg.shape) == len(data.shape) - 1:
        seg = seg[None]
    if patch_size is None:
        # Default to the full spatial extent of the input.
        patch_size = tuple(data.shape[1:])

    dim = len(patch_size)
    if random_crop and patch_center_dist_from_border is None:
        # Fail clearly instead of the obscure TypeError from ``int - None``.
        raise ValueError(
            "patch_center_dist_from_border must be set when random_crop=True")

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], *patch_size), dtype=np.float32)
    data_result = np.zeros((data.shape[0], *patch_size), dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    coords = create_zero_centered_coordinate_mesh(patch_size)
    modified_coords = False
    if np.random.uniform() < p_el_per_sample and do_elastic_deform:
        a = np.random.uniform(alpha[0], alpha[1])
        s = np.random.uniform(sigma[0], sigma[1])
        coords = elastic_deform_coordinates(coords, a, s)
        modified_coords = True

    if np.random.uniform() < p_rot_per_sample and do_rotation:
        if angle_x[0] == angle_x[1]:
            a_x = angle_x[0]
        else:
            a_x = np.random.uniform(angle_x[0], angle_x[1])
        if dim == 3:
            if angle_y[0] == angle_y[1]:
                a_y = angle_y[0]
            else:
                a_y = np.random.uniform(angle_y[0], angle_y[1])
            if angle_z[0] == angle_z[1]:
                a_z = angle_z[0]
            else:
                a_z = np.random.uniform(angle_z[0], angle_z[1])
            coords = rotate_coords_3d(coords, a_x, a_y, a_z)
        else:
            coords = rotate_coords_2d(coords, a_x)
        modified_coords = True

    if np.random.uniform() < p_scale_per_sample and do_scale:
        if np.random.random() < 0.5 and scale[0] < 1:
            sc = np.random.uniform(scale[0], 1)
        else:
            sc = np.random.uniform(max(scale[0], 1), scale[1])
        coords = scale_coords(coords, sc)
        modified_coords = True

    if modified_coords:
        # Shift the zero-centered grid to the chosen patch center.
        for d in range(dim):
            if random_crop:
                ctr = np.random.uniform(patch_center_dist_from_border[d],
                                        data.shape[d + 1] - patch_center_dist_from_border[d])
            else:
                ctr = int(np.around(data.shape[d + 1] / 2.))
            coords[d] += ctr
        for channel_id in range(data.shape[0]):
            data_result[channel_id] = interpolate_img(data[channel_id], coords, order_data,
                                                      border_mode_data, cval=border_cval_data)
        if seg is not None:
            for channel_id in range(seg.shape[0]):
                seg_result[channel_id] = interpolate_img(seg[channel_id], coords, order_seg,
                                                         border_mode_seg, cval=border_cval_seg, is_seg=True)
    else:
        # No spatial transform was applied: crop only. The crop helpers work on
        # batches, hence the temporary leading axis.
        seg_batch = None if seg is None else seg[None]
        if random_crop:
            margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
            cropped_data, cropped_seg = random_crop_aug(data[None], seg_batch, patch_size, margin)
        else:
            cropped_data, cropped_seg = center_crop_aug(data[None], patch_size, seg_batch)
        data_result = cropped_data[0]
        if seg is not None:
            seg_result = cropped_seg[0]
    return data_result, seg_result