def rotate(self, data_dict):
    """Rotate every sample of a batch by the per-sample angles stored on ``self``.

    ``data_dict["data"]`` is indexed as ``[sample, channel, *spatial]``. For
    each sample, a zero-centered coordinate grid over the spatial shape is
    built and then rotated:

    * 3 spatial dims: rotated with ``self.rand_params['ax'/'ay'/'az'][i]``;
    * otherwise: rotated with ``self.rand_params['ax'][i]``.

    The grid is then un-centered and every channel is resampled with the
    ``self.params`` data settings. If the input dict has a ``"seg"`` key, that
    array is resampled on the same grid with the segmentation settings and
    ``is_seg=True``.

    Results are written back into the input arrays. The returned dict holds
    only ``'data'`` and ``'seg'`` (``None`` when there was no ``"seg"`` key);
    other keys of ``data_dict`` are not carried over.
    """
    images = data_dict["data"]
    has_seg = "seg" in data_dict
    labels = data_dict["seg"] if has_seg else None

    spatial_shape = np.array(images.shape[2:])
    is_volume = len(spatial_shape) == 3

    for i in range(images.shape[0]):
        grid = create_zero_centered_coordinate_mesh(spatial_shape)
        if is_volume:
            grid = rotate_coords_3d(grid,
                                    self.rand_params['ax'][i],
                                    self.rand_params['ay'][i],
                                    self.rand_params['az'][i])
        else:
            grid = rotate_coords_2d(grid, self.rand_params['ax'][i])
        grid = uncenter_coords(grid)

        for ch in range(images.shape[1]):
            images[i, ch] = interpolate_img(images[i, ch], grid,
                                            self.params['order_data'],
                                            self.params['bmode_data'],
                                            cval=self.params['bcval_data'])

        if not has_seg:
            continue
        for ch in range(labels.shape[1]):
            labels[i, ch] = interpolate_img(labels[i, ch], grid,
                                            self.params['order_seg'],
                                            self.params['bmode_seg'],
                                            cval=self.params['bcval_seg'],
                                            is_seg=True)

    return {'data': images, 'seg': labels}
def batch_deform_3d(vol, coords, bspline_order, border_mode, constant_val):
    """Warp every (sample, channel) image of ``vol`` on one shared coordinate grid.

    :param vol: array whose first two axes are iterated as (sample, channel);
        each ``vol[i, j]`` is resampled independently.
    :param coords: sampling coordinates shared by all samples and channels.
    :param bspline_order: interpolation order, forwarded as ``order``.
    :param border_mode: border handling, forwarded as ``mode``.
    :param constant_val: fill value, forwarded as ``cval``.
    :return: float32 array with the same shape as ``vol``.
    """
    warped = np.zeros_like(vol, dtype=np.float32)
    # np.ndindex walks (sample, channel) pairs in C order, i.e. channel varies fastest.
    for idx in np.ndindex(vol.shape[0], vol.shape[1]):
        warped[idx] = interpolate_img(vol[idx],
                                      coords,
                                      order=bspline_order,
                                      mode=border_mode,
                                      cval=constant_val)
    return warped
 def augment_spatial(self, data, coords, is_label=False):
     """Resample ``data`` on ``coords`` with the image or label settings of ``self``.

     :param data: array passed straight to ``interpolate_img``.
     :param coords: sampling coordinates passed straight to ``interpolate_img``.
     :param is_label: when true, use ``self.order_seg``, ``self.border_mode_seg``
         and ``self.border_cval_seg`` and interpolate with ``is_seg=True``;
         otherwise use the corresponding ``*_data`` settings.
     :return: the array returned by ``interpolate_img``.
     """
     if is_label:
         order = self.order_seg
         border_mode = self.border_mode_seg
         border_cval = self.border_cval_seg
     else:
         order = self.order_data
         border_mode = self.border_mode_data
         border_cval = self.border_cval_data
     # Label maps are interpolated with is_seg=True, matching how segmentations
     # are resampled elsewhere; without it an order > 0 would blend neighbouring
     # label ids. NOTE(review): assumes interpolate_img's is_seg path keeps
     # labels discrete (and is a no-op change for order 0) — confirm.
     return interpolate_img(data,
                            coords,
                            order,
                            border_mode,
                            cval=border_cval,
                            is_seg=bool(is_label))
Exemple #4
0
def augment_spatial(data,
                    seg,
                    patch_size,
                    labels_extra=None,
                    patch_center_dist_from_border=30,
                    do_elastic_deform=True,
                    alpha=(0., 1000.),
                    sigma=(10., 13.),
                    do_rotation=True,
                    angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi),
                    angle_z=(0, 2 * np.pi),
                    do_scale=True,
                    scale=(0.75, 1.25),
                    border_mode_data='nearest',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=1,
                    p_scale_per_sample=1,
                    p_rot_per_sample=1,
                    independent_scale_for_each_axis=False,
                    p_rot_per_axis: float = 1):
    """Spatially augment a batch and crop it to ``patch_size``.

    For each sample, a zero-centered coordinate grid of size ``patch_size`` is
    optionally elastically deformed, rotated and scaled, each step with its own
    probability. The grid is then shifted to a random or central location, and
    ``data``, ``seg`` and every array in ``labels_extra`` are resampled on it.
    If no transform was sampled for a sample, it is cropped with
    ``random_crop_aug`` / ``center_crop_aug`` instead.

    :param data: batch indexed as ``data[sample, channel, *spatial]``.
    :param seg: segmentation batch indexed like ``data``, or None.
    :param patch_size: spatial output size; its length (2 or 3) sets the
        dimensionality.
    :param labels_extra: optional sequence of additional label batches, indexed
        like ``seg``; each is transformed with the segmentation settings. Each
        output buffer is sized from its own array, so ``seg`` may be None and
        channel counts may differ.
    :param patch_center_dist_from_border: minimum distance (per axis, or one
        scalar) between a random patch center and the border of ``data``.
    :param alpha, sigma: ``(low, high)`` ranges for the elastic deformation.
    :param angle_x, angle_y, angle_z: ``(low, high)`` rotation ranges; only
        ``angle_x`` is used for 2D patches.
    :param scale: ``(low, high)`` scale range. When ``low < 1``, the factor is
        drawn from ``[low, 1)`` with probability 0.5, otherwise from
        ``[max(low, 1), high]``.
    :param independent_scale_for_each_axis: draw one scale factor per axis.
    :param p_rot_per_axis: probability of rotating around each individual axis
        once a sample is selected for rotation.
    :return: ``(data_result, seg_result, seg_extra_outputs)``. All arrays are
        float32 of shape ``(batch, channels, *patch_size)``. ``seg_result`` is
        None when ``seg`` is None, and ``seg_extra_outputs`` has one array per
        entry of ``labels_extra`` (empty list if None).
    """
    dim = len(patch_size)

    def _patch_buffer(arr):
        # float32 buffer of shape (arr batch, arr channels, *patch_size).
        if dim == 2:
            shape = (arr.shape[0], arr.shape[1], patch_size[0], patch_size[1])
        else:
            shape = (arr.shape[0], arr.shape[1], patch_size[0], patch_size[1],
                     patch_size[2])
        return np.zeros(shape, dtype=np.float32)

    seg_result = None if seg is None else _patch_buffer(seg)

    # shuailin: extra segs. Sized from each extra label array itself (sizing
    # them from ``seg`` crashed when seg was None and overflowed when an extra
    # label had more channels than seg).
    seg_extra_outputs = []
    if labels_extra is not None:
        seg_extra_outputs = [_patch_buffer(lbl) for lbl in labels_extra]

    data_result = _patch_buffer(data)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            el_alpha = np.random.uniform(alpha[0], alpha[1])
            el_sigma = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, el_alpha, el_sigma)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            # each axis is rotated independently with probability p_rot_per_axis
            if np.random.uniform() <= p_rot_per_axis:
                a_x = np.random.uniform(angle_x[0], angle_x[1])
            else:
                a_x = 0

            if dim == 3:
                if np.random.uniform() <= p_rot_per_axis:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])
                else:
                    a_y = 0

                if np.random.uniform() <= p_rot_per_axis:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])
                else:
                    a_z = 0

                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if not independent_scale_for_each_axis:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
            else:
                sc = []
                for _ in range(dim):
                    if np.random.random() < 0.5 and scale[0] < 1:
                        sc.append(np.random.uniform(scale[0], 1))
                    else:
                        sc.append(np.random.uniform(max(scale[0], 1),
                                                    scale[1]))
            coords = scale_coords(coords, sc)
            modified_coords = True

        # now find a nice center location
        if modified_coords:
            for ax in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[ax],
                        data.shape[ax + 2] - patch_center_dist_from_border[ax])
                else:
                    ctr = int(np.round(data.shape[ax + 2] / 2.))
                coords[ax] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
            # shuailin: extra segs transformation
            if labels_extra is not None:
                for ith, seg_arr in enumerate(labels_extra):
                    for channel_id in range(seg_arr.shape[1]):
                        seg_extra_outputs[ith][sample_id, channel_id] = \
                            interpolate_img(seg_arr[sample_id, channel_id],
                                            coords,
                                            order_seg,
                                            border_mode_seg,
                                            cval=border_cval_seg,
                                            is_seg=True)
        else:
            seg_sample = None if seg is None else seg[sample_id:sample_id + 1]
            # shuailin extension
            if labels_extra is not None:
                labels_extra_sample = [
                    lbl[sample_id:sample_id + 1] for lbl in labels_extra
                ]
            else:
                labels_extra_sample = None

            if random_crop:
                margin = [
                    patch_center_dist_from_border[ax] - patch_size[ax] // 2
                    for ax in range(dim)
                ]
                data_crop, seg_crop, seg_extra_output = random_crop_aug(
                    data[sample_id:sample_id + 1],
                    seg_sample,
                    patch_size,
                    margin,
                    labels_extra=labels_extra_sample)
            else:
                data_crop, seg_crop, seg_extra_output = center_crop_aug(
                    data[sample_id:sample_id + 1],
                    patch_size,
                    seg_sample,
                    labels_extra=labels_extra_sample)
            data_result[sample_id] = data_crop[0]
            if seg is not None:
                seg_result[sample_id] = seg_crop[0]
            # shuailin extension
            # NOTE(review): unlike data/seg this is not indexed with [0]; relies
            # on the crop helper's labels_extra return shape — confirm.
            if labels_extra is not None:
                for ith in range(len(labels_extra)):
                    seg_extra_outputs[ith][sample_id] = seg_extra_output[ith]

    return data_result, seg_result, seg_extra_outputs
def augment_spatial_peaks(data,
                          seg,
                          patch_size,
                          patch_center_dist_from_border=30,
                          do_elastic_deform=True,
                          alpha=(0., 1000.),
                          sigma=(10., 13.),
                          do_rotation=True,
                          angle_x=(0, 2 * np.pi),
                          angle_y=(0, 2 * np.pi),
                          angle_z=(0, 2 * np.pi),
                          do_scale=True,
                          scale=(0.75, 1.25),
                          border_mode_data='nearest',
                          border_cval_data=0,
                          order_data=3,
                          border_mode_seg='constant',
                          border_cval_seg=0,
                          order_seg=0,
                          random_crop=True,
                          p_el_per_sample=1,
                          p_scale_per_sample=1,
                          p_rot_per_sample=1,
                          slice_dir=None):
    """Spatially augment a 2D batch of peak/tensor images and rotate their vectors too.

    Each sample's zero-centered 2D coordinate grid is optionally elastically
    deformed, rotated by an in-plane angle and scaled. ``data`` and ``seg`` are
    then resampled on it, or cropped when no transform was sampled. Finally,
    the channels of each resampled sample are rotated as directional data:
    9 channels go through ``rotate_multiple_peaks`` and 18 channels through
    ``rotate_multiple_tensors``. The in-plane angle is applied around the axis
    given by ``slice_dir``: 0 -> x, 1 -> y, 2 -> z (with inverted sign).

    :param data: batch indexed as ``data[sample, channel, *spatial]``; must have
        9 or 18 channels.
    :param seg: segmentation batch indexed like ``data``, or None.
    :param patch_size: 2D spatial output size.
    :param angle_x: ``(low, high)`` in-plane rotation range; a degenerate range
        ``low == high`` uses ``low`` exactly.
    :param angle_y, angle_z: accepted for signature compatibility; unused
        because only 2D patches are supported.
    :param slice_dir: 0, 1 or 2 — the axis the 2D slices were taken along.
    :return: ``(data_result, seg_result)`` float32 arrays of shape
        ``(batch, channels, *patch_size)``; ``seg_result`` is None when ``seg``
        is None.
    :raises ValueError: if ``patch_size`` has more than 2 dims, ``slice_dir`` is
        not 0/1/2, or ``data`` does not have 9 or 18 channels. These checks now
        run before any sample is processed rather than inside the loop.
    """
    dim = len(patch_size)

    # Fail fast on everything that can be validated up front, instead of after
    # the first sample has already been deformed and interpolated.
    if dim > 2:
        raise ValueError(
            "augment_spatial_peaks only supports 2D at the moment")
    if slice_dir not in (0, 1, 2):
        raise ValueError("invalid slice_dir passed as argument")
    if data.shape[1] not in (9, 18):
        raise ValueError("Incorrect number of channels (expected 9 or 18)")

    seg_result = None
    if seg is not None:
        seg_result = np.zeros(
            (seg.shape[0], seg.shape[1], patch_size[0], patch_size[1]),
            dtype=np.float32)
    data_result = np.zeros(
        (data.shape[0], data.shape[1], patch_size[0], patch_size[1]),
        dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        # The random draw happens before the do_* flag is checked, so the RNG
        # stream is consumed even when the transform is disabled.
        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            el_alpha = np.random.uniform(alpha[0], alpha[1])
            el_sigma = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, el_alpha, el_sigma)
            modified_coords = True

        # In-plane angle; stays 0 when no rotation is sampled, so the
        # peaks/tensors below are then rotated by zero angles.
        angle = 0
        if np.random.uniform() < p_rot_per_sample and do_rotation:
            if angle_x[0] == angle_x[1]:
                angle = angle_x[0]
            else:
                angle = np.random.uniform(angle_x[0], angle_x[1])
            coords = rotate_coords_2d(coords, angle)
            modified_coords = True

        if np.random.uniform() < p_scale_per_sample and do_scale:
            if np.random.random() < 0.5 and scale[0] < 1:
                sc = np.random.uniform(scale[0], 1)
            else:
                sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)
            modified_coords = True

        # now find a nice center location
        if modified_coords:
            for ax in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[ax],
                        data.shape[ax + 2] - patch_center_dist_from_border[ax])
                else:
                    ctr = int(np.round(data.shape[ax + 2] / 2.))
                coords[ax] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
        else:
            seg_sample = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [
                    patch_center_dist_from_border[ax] - patch_size[ax] // 2
                    for ax in range(dim)
                ]
                data_crop, seg_crop = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_sample, patch_size,
                    margin)
            else:
                data_crop, seg_crop = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_sample)
            data_result[sample_id] = data_crop[0]
            if seg is not None:
                seg_result[sample_id] = seg_crop[0]

        # Rotate peaks / tensors around the axis matching the slice direction.
        a_x = a_y = a_z = 0
        if slice_dir == 0:
            a_x = angle
        elif slice_dir == 1:
            a_y = angle
        else:  # slice_dir == 2 (validated above)
            # Somehow we have to invert rotation direction for z to make it align
            # properly with rotated voxels. Unclear why this is the case. Maybe
            # some different conventions for peaks and voxels??
            a_z = angle * -1

        data_aug = data_result[sample_id]
        if data_aug.shape[0] == 9:
            data_result[sample_id] = rotate_multiple_peaks(
                data_aug, a_x, a_y, a_z)
        else:  # 18 channels (validated above)
            data_result[sample_id] = rotate_multiple_tensors(
                data_aug, a_x, a_y, a_z)

    return data_result, seg_result
Exemple #6
0
def augment_spatial_2(data,
                      seg,
                      patch_size,
                      patch_center_dist_from_border=30,
                      do_elastic_deform=True,
                      deformation_scale=(0, 0.25),
                      do_rotation=True,
                      angle_x=(0, 2 * np.pi),
                      angle_y=(0, 2 * np.pi),
                      angle_z=(0, 2 * np.pi),
                      do_scale=True,
                      scale=(0.75, 1.25),
                      border_mode_data='nearest',
                      border_cval_data=0,
                      order_data=3,
                      border_mode_seg='constant',
                      border_cval_seg=0,
                      order_seg=0,
                      random_crop=True,
                      p_el_per_sample=1,
                      p_scale_per_sample=1,
                      p_rot_per_sample=1,
                      independent_scale_for_each_axis=False,
                      p_rot_per_axis: float = 1,
                      p_independent_scale_per_axis: float = 1):
    """Spatially augment a batch (elastic deformation, rotation, scaling) and crop it.

    :param data: batch indexed as ``data[sample, channel, *spatial]``.
    :param seg: segmentation batch indexed like ``data``, or None.
    :param patch_size: spatial output size; its length (2 or 3) sets the
        dimensionality.
    :param patch_center_dist_from_border: minimum distance (per axis, or one
        scalar) between a random patch center and the border of ``data``.
    :param do_elastic_deform: enable the elastic deformation.
    :param deformation_scale: ``(low, high)`` range for one per-sample scale,
        relative to ``patch_size``. The deformation sigma along axis d is
        ``scale * patch_size[d]``; small values give local deformations and
        large values give large ones. The magnitude is drawn uniformly from
        ``[sigma / 8, sigma / 2]``. Example: 0.125 gives a sigma of 12.5% of
        the patch size.
    :param do_rotation: enable rotation.
    :param angle_x, angle_y, angle_z: ``(low, high)`` rotation ranges
        (presumably radians, given the ``2 * np.pi`` defaults); only
        ``angle_x`` is used for 2D patches.
    :param do_scale: enable scaling.
    :param scale: ``(low, high)`` scale range. When ``low < 1``, a factor is
        drawn from ``[low, 1)`` with probability 0.5, otherwise from
        ``[max(low, 1), high]``.
    :param border_mode_data, border_cval_data, order_data: interpolation border
        mode, fill value and order for ``data``.
    :param border_mode_seg, border_cval_seg, order_seg: the same for ``seg``,
        which is interpolated with ``is_seg=True``.
    :param random_crop: choose the patch center at random (respecting
        ``patch_center_dist_from_border``); otherwise use the volume center.
    :param p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
        probabilities of applying each transform.
    :param independent_scale_for_each_axis: allow one scale factor per axis.
    :param p_rot_per_axis: probability of rotating around each individual axis
        once a sample is selected for rotation.
    :param p_independent_scale_per_axis: when
        ``independent_scale_for_each_axis`` is set, the probability of drawing
        per-axis factors for a sample (otherwise one shared factor is used).
    :return: ``(data_result, seg_result)`` float32 arrays of shape
        ``(batch, channels, *patch_size)``; ``seg_result`` is None when ``seg``
        is None.
    """
    dim = len(patch_size)

    def _patch_buffer(arr):
        # float32 buffer of shape (arr batch, arr channels, *patch_size).
        if dim == 2:
            shape = (arr.shape[0], arr.shape[1], patch_size[0], patch_size[1])
        else:
            shape = (arr.shape[0], arr.shape[1], patch_size[0], patch_size[1],
                     patch_size[2])
        return np.zeros(shape, dtype=np.float32)

    seg_result = None if seg is None else _patch_buffer(seg)
    data_result = _patch_buffer(data)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            mag = []
            sigmas = []

            # one scale per case, scale is in percent of patch_size
            def_scale = np.random.uniform(deformation_scale[0],
                                          deformation_scale[1])

            for ax in range(dim):
                # transform relative def_scale in pixels
                sigmas.append(def_scale * patch_size[ax])

                # The magnitude must depend on the scale, otherwise not much
                # happens most of the time; above sigma / 2 the deformations
                # become very ugly.
                max_magnitude = sigmas[-1] * (1 / 2)
                min_magnitude = sigmas[-1] * (1 / 8)
                mag.append(np.random.uniform(min_magnitude, max_magnitude))
            coords = elastic_deform_coordinates_2(coords, sigmas, mag)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            # each axis is rotated independently with probability p_rot_per_axis
            if np.random.uniform() <= p_rot_per_axis:
                a_x = np.random.uniform(angle_x[0], angle_x[1])
            else:
                a_x = 0

            if dim == 3:
                if np.random.uniform() <= p_rot_per_axis:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])
                else:
                    a_y = 0

                if np.random.uniform() <= p_rot_per_axis:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])
                else:
                    a_z = 0

                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if independent_scale_for_each_axis and np.random.uniform(
            ) < p_independent_scale_per_axis:
                sc = []
                for _ in range(dim):
                    if np.random.random() < 0.5 and scale[0] < 1:
                        sc.append(np.random.uniform(scale[0], 1))
                    else:
                        sc.append(np.random.uniform(max(scale[0], 1),
                                                    scale[1]))
            else:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])

            coords = scale_coords(coords, sc)
            modified_coords = True

        # now find a nice center location
        if modified_coords:
            # recenter coordinates (the deformation may have shifted their mean)
            coords_mean = coords.mean(axis=tuple(range(1, len(coords.shape))),
                                      keepdims=True)
            coords -= coords_mean

            for ax in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[ax],
                        data.shape[ax + 2] - patch_center_dist_from_border[ax])
                else:
                    ctr = int(np.round(data.shape[ax + 2] / 2.))
                coords[ax] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
        else:
            seg_sample = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [
                    patch_center_dist_from_border[ax] - patch_size[ax] // 2
                    for ax in range(dim)
                ]
                data_crop, seg_crop = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_sample, patch_size,
                    margin)
            else:
                data_crop, seg_crop = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_sample)
            data_result[sample_id] = data_crop[0]
            if seg is not None:
                seg_result[sample_id] = seg_crop[0]
    return data_result, seg_result
Exemple #7
0
def augment_spatial_transform(data, seg=None, patch_size=None, patch_center_dist_from_border=None,
                              do_elastic_deform=True, alpha=(0., 900.), sigma=(9., 13.), do_rotation=True,
                              angle_x=default_3D_augmentation_params.get("rotation_x"),
                              angle_y=default_3D_augmentation_params.get("rotation_y"),
                              angle_z=default_3D_augmentation_params.get("rotation_z"),
                              do_scale=True, scale=(0.85, 1.25),
                              border_mode_data='constant', border_cval_data=0, order_data=3,
                              border_mode_seg="constant", border_cval_seg=-1, order_seg=1,
                              random_crop=False, p_el_per_sample=1,
                              p_scale_per_sample=1, p_rot_per_sample=1):
    """Spatially augment a single (un-batched) image and crop it to ``patch_size``.

    One elastic deformation, rotation and scaling are sampled (each with its
    own probability) and applied to every channel of ``data`` and ``seg``.
    When none is sampled, the input is cropped with ``random_crop_aug`` /
    ``center_crop_aug`` instead.

    :param data: array indexed as ``data[channel, *spatial]``. A 3-dimensional
        array is treated as a single-channel volume: a notice is printed and a
        leading channel axis is added, so the output keeps that axis.
    :param seg: optional segmentation indexed like ``data``. A 3-dimensional
        ``seg`` gets a channel axis only when ``data`` also did.
    :param patch_size: spatial output size; defaults to the three spatial
        extents of ``data``.
    :param patch_center_dist_from_border: minimum distance (per axis, or one
        scalar) between a random patch center and the border. ``None`` (the
        default) now means half the patch size per axis. Previously ``None``
        crashed as soon as ``random_crop`` was enabled; it is unused when
        ``random_crop`` is False.
    :param angle_x, angle_y, angle_z: ``(low, high)`` rotation ranges; a
        degenerate range ``low == high`` uses ``low`` exactly. Only ``angle_x``
        is used for 2D patches.
    :param scale: ``(low, high)`` scale range. When ``low < 1``, the factor is
        drawn from ``[low, 1)`` with probability 0.5, otherwise from
        ``[max(low, 1), high]``.
    :param random_crop: choose the patch center at random; otherwise use the
        image center.
    :return: ``(data_result, seg_result)``; ``seg_result`` is None when ``seg``
        is None.
    """
    if len(data.shape) == 3:
        print("Attention: data shape must be (C, D, H, W). I will transform it.")
        data = data[None]
        if seg is not None and len(seg.shape) == 3:
            seg = seg[None]
    if patch_size is None:
        patch_size = (data.shape[1], data.shape[2], data.shape[3])

    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        if dim == 2:
            seg_result = np.zeros((seg.shape[0], patch_size[0], patch_size[1]), dtype=np.float32)
        else:
            seg_result = np.zeros((seg.shape[0], patch_size[0], patch_size[1], patch_size[2]),
                                  dtype=np.float32)

    if dim == 2:
        data_result = np.zeros((data.shape[0], patch_size[0], patch_size[1]), dtype=np.float32)
    else:
        data_result = np.zeros((data.shape[0], patch_size[0], patch_size[1], patch_size[2]),
                               dtype=np.float32)

    if patch_center_dist_from_border is None:
        # Keep the patch center at least half a patch away from every border
        # (this yields a crop margin of 0 below).
        patch_center_dist_from_border = [p // 2 for p in patch_size]
    elif not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    coords = create_zero_centered_coordinate_mesh(patch_size)
    modified_coords = False
    # The random draws happen before the do_* flags are checked, so the RNG
    # stream is consumed even when a transform is disabled.
    if np.random.uniform() < p_el_per_sample and do_elastic_deform:
        el_alpha = np.random.uniform(alpha[0], alpha[1])
        el_sigma = np.random.uniform(sigma[0], sigma[1])
        coords = elastic_deform_coordinates(coords, el_alpha, el_sigma)
        modified_coords = True

    if np.random.uniform() < p_rot_per_sample and do_rotation:
        if angle_x[0] == angle_x[1]:
            a_x = angle_x[0]
        else:
            a_x = np.random.uniform(angle_x[0], angle_x[1])
        if dim == 3:
            if angle_y[0] == angle_y[1]:
                a_y = angle_y[0]
            else:
                a_y = np.random.uniform(angle_y[0], angle_y[1])
            if angle_z[0] == angle_z[1]:
                a_z = angle_z[0]
            else:
                a_z = np.random.uniform(angle_z[0], angle_z[1])
            coords = rotate_coords_3d(coords, a_x, a_y, a_z)
        else:
            coords = rotate_coords_2d(coords, a_x)
        modified_coords = True

    if np.random.uniform() < p_scale_per_sample and do_scale:
        if np.random.random() < 0.5 and scale[0] < 1:
            sc = np.random.uniform(scale[0], 1)
        else:
            sc = np.random.uniform(max(scale[0], 1), scale[1])
        coords = scale_coords(coords, sc)
        modified_coords = True

    # now find a nice center location
    if modified_coords:
        for ax in range(dim):
            if random_crop:
                ctr = np.random.uniform(patch_center_dist_from_border[ax],
                                        data.shape[ax + 1] - patch_center_dist_from_border[ax])
            else:
                ctr = int(np.around(data.shape[ax + 1] / 2.))
            coords[ax] += ctr
        for channel_id in range(data.shape[0]):
            data_result[channel_id] = interpolate_img(data[channel_id], coords, order_data,
                                                      border_mode_data, cval=border_cval_data)
        if seg is not None:
            for channel_id in range(seg.shape[0]):
                seg_result[channel_id] = interpolate_img(seg[channel_id], coords, order_seg,
                                                         border_mode_seg, cval=border_cval_seg, is_seg=True)
    else:
        seg_batch = None if seg is None else seg[None]
        if random_crop:
            margin = [patch_center_dist_from_border[ax] - patch_size[ax] // 2 for ax in range(dim)]
            data_crop, seg_crop = random_crop_aug(data[None], seg_batch, patch_size, margin)
        else:
            data_crop, seg_crop = center_crop_aug(data[None], patch_size, seg_batch)
        data_result = data_crop[0]
        if seg is not None:
            seg_result = seg_crop[0]
    return data_result, seg_result
Exemple #8
0
def cc_augment(config_task,
               data,
               seg,
               patch_type,
               patch_size,
               patch_center_dist_from_border=30,
               do_elastic_deform=True,
               alpha=(0., 1000.),
               sigma=(10., 13.),
               do_rotation=True,
               angle_x=(0, 2 * np.pi),
               angle_y=(0, 2 * np.pi),
               angle_z=(0, 2 * np.pi),
               do_scale=True,
               scale=(0.75, 1.25),
               border_mode_data='constant',
               border_cval_data=0,
               order_data=3,
               border_mode_seg='constant',
               border_cval_seg=0,
               order_seg=0,
               random_crop=True,
               p_el_per_sample=1,
               p_scale_per_sample=1,
               p_rot_per_sample=1,
               tag=''):
    """Spatially augment a batch and extract one patch per sample, retrying
    until the patch contains the wanted content.

    For every sample, elastic deformation / rotation / scaling are each drawn
    with their own probability. A patch centre is then chosen:

    * ``patch_type`` 'fore' / 'small': near a random voxel of the target
      label (any label > 0 for 'fore'; label 1 for Task05_Prostate or
      ``config_task.num_class - 1`` otherwise for 'small').
    * 'any' (forced when ``seg`` is None): anywhere at least
      ``patch_center_dist_from_border`` away from the border.

    An attempt is accepted when its patch has non-zero data and, for
    'fore' / 'small', contains the target label. Attempts 1-10 resample the
    augmented grid with ``interpolate_img``. From the 11th attempt on, or
    whenever no augmentation was drawn, the patch is a plain crop instead.

    NOTE: ``patch_center_dist_from_border`` should be no more than half the
    patch size.
    NOTE(review): a sample whose data is entirely zero is never accepted, so
    the retry loop does not terminate for it. Confirm callers exclude such
    samples.

    :param config_task: object with ``task`` and ``num_class``; only read for
        ``patch_type == 'small'``.
    :param data: array indexed as [sample, channel, *spatial].
    :param seg: array indexed like ``data`` (only channel 0 is used to pick
        centres), or None.
    :param patch_type: 'fore', 'small' or 'any'.
    :param patch_size: spatial patch size (list or tuple, 2 or 3 entries).
    :param random_crop: choose the centre as described above if True,
        otherwise use the image centre.
    :param tag: free text included in log messages.
    :return: ``(data_result, seg_result, augs)``. ``data_result`` and
        ``seg_result`` are float32 arrays of shape [n, c, *patch_size]
        (``seg_result`` is None when ``seg`` is None). ``augs`` lists the
        augmentation names used by each sample's accepted attempt,
        concatenated over samples. If a 'fore' / 'small' sample contains no
        target voxel, ``(None, None, None)`` is returned.
    """
    dim = len(patch_size)
    # The output shapes are built by list concatenation, so normalise tuples.
    patch_size = list(patch_size)

    seg_result = None
    if seg is not None:
        seg_result = np.zeros([seg.shape[0], seg.shape[1]] + patch_size,
                              dtype=np.float32)

    data_result = np.zeros([data.shape[0], data.shape[1]] + patch_size,
                           dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    # Without a segmentation there is no label to aim the patch at.
    if seg is None:
        patch_type = 'any'
    targeted = patch_type in ('fore', 'small')

    def _target_mask(labels):
        """Boolean mask of the voxels a 'fore' / 'small' patch must contain."""
        if patch_type == 'fore':
            return labels > 0
        if config_task.task == 'Task05_Prostate':
            return labels == 1
        return labels == config_task.num_class - 1

    augs = list()
    for sample_id in range(data.shape[0]):
        accepted = False
        n = 0
        while not accepted:
            # Start every attempt from a fresh mesh: transforms and the centre
            # offset must not compound across retries.
            coords = create_zero_centered_coordinate_mesh(patch_size)
            attempt_augs = list()
            modified_coords = False

            if np.random.uniform() < p_el_per_sample and do_elastic_deform:
                a = np.random.uniform(alpha[0], alpha[1])
                s = np.random.uniform(sigma[0], sigma[1])
                coords = elastic_deform_coordinates(coords, a, s)
                modified_coords = True
                attempt_augs.append('elastic')

            if np.random.uniform() < p_rot_per_sample and do_rotation:
                if angle_x[0] == angle_x[1]:
                    a_x = angle_x[0]
                else:
                    a_x = np.random.uniform(angle_x[0], angle_x[1])
                if dim == 3:
                    if angle_y[0] == angle_y[1]:
                        a_y = angle_y[0]
                    else:
                        a_y = np.random.uniform(angle_y[0], angle_y[1])
                    if angle_z[0] == angle_z[1]:
                        a_z = angle_z[0]
                    else:
                        a_z = np.random.uniform(angle_z[0], angle_z[1])
                    coords = rotate_coords_3d(coords, a_x, a_y, a_z)
                else:
                    coords = rotate_coords_2d(coords, a_x)
                modified_coords = True
                attempt_augs.append('rotation')

            if np.random.uniform() < p_scale_per_sample and do_scale:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
                coords = scale_coords(coords, sc)
                modified_coords = True
                attempt_augs.append('scale')

            # Pick one random voxel of the target label as the anchor point.
            cand_point_coord = None
            if targeted:
                if seg.shape[1] > 1:
                    logger.error('TBD for seg with multiple channels')
                lab_coords = np.where(_target_mask(seg[sample_id, 0, ...]))
                if len(lab_coords[0]) == 0:
                    # The target label is absent from this sample.
                    return None, None, None
                idx = np.random.choice(len(lab_coords[0]))
                cand_point_coord = [axis_coords[idx] for axis_coords in lab_coords]

            # Patch centre per spatial axis.
            ctr_list = list()
            for d in range(dim):
                extent = data.shape[d + 2]
                margin = patch_center_dist_from_border[d]
                if not random_crop:
                    ctr = int(np.round(extent / 2.))
                elif targeted:
                    # Keep the anchor voxel inside the patch while respecting
                    # the border margin (+/- patch_size[d] would be exact but
                    # more costly).
                    low = int(max(margin - 1,
                                  cand_point_coord[d] - (patch_size[d] / 2 - 1)))
                    upper = int(min(cand_point_coord[d] + (patch_size[d] / 2 - 1),
                                    extent - (margin - 1)))
                    if low == upper:
                        ctr = low
                    elif low < upper:
                        ctr = int(np.random.randint(low, upper))
                    else:
                        logger.error(
                            '(low:{} should be <= upper:{}). patch_type:{}, patch_center_dist_from_border:{}, cand_point_coord:{}, cand point seg value:{}, data.shape:{}, ctr_list:{}'
                            .format(
                                low, upper, str(patch_type),
                                str(patch_center_dist_from_border),
                                str(cand_point_coord),
                                seg[sample_id, 0][tuple(cand_point_coord)],
                                str(data.shape), str(ctr_list)))
                        # Fall back to centring on the anchor voxel itself.
                        ctr = int(cand_point_coord[d])
                else:
                    if margin == extent - margin:
                        ctr = int(margin)
                    elif margin < extent - margin:
                        ctr = int(np.random.randint(margin, extent - margin))
                    else:
                        logger.error(
                            'low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, data.shape:{}, ctr_list:{}'
                            .format(str(patch_type),
                                    str(patch_center_dist_from_border),
                                    str(data.shape), str(ctr_list)))
                        # Fall back to the image centre along this axis.
                        ctr = int(np.round(extent / 2.))
                ctr_list.append(ctr)

            # Extract the patch.
            if n < 10 and modified_coords:
                for d in range(dim):
                    coords[d] += ctr_list[d]
                for channel_id in range(data.shape[1]):
                    data_result[sample_id, channel_id] = interpolate_img(
                        data[sample_id, channel_id],
                        coords,
                        order_data,
                        border_mode_data,
                        cval=border_cval_data)
                if seg is not None:
                    for channel_id in range(seg.shape[1]):
                        seg_result[sample_id, channel_id] = interpolate_img(
                            seg[sample_id, channel_id],
                            coords,
                            order_seg,
                            border_mode_seg,
                            cval=border_cval_seg,
                            is_seg=True)
            else:
                # Plain crop without spatial augmentation.
                attempt_augs = list()
                if random_crop:
                    data_result[sample_id] = np.asarray([
                        utils.extract_roi_from_volume(
                            data[sample_id, channel_id],
                            ctr_list,
                            patch_size,
                            fill="zero")
                        for channel_id in range(data.shape[1])
                    ])
                    if seg is not None:
                        seg_result[sample_id] = np.asarray([
                            utils.extract_roi_from_volume(
                                seg[sample_id, channel_id],
                                ctr_list,
                                patch_size,
                                fill="zero")
                            for channel_id in range(seg.shape[1])
                        ])
                else:
                    s = None if seg is None else seg[sample_id:sample_id + 1]
                    d_crop, s_crop = center_crop_aug(
                        data[sample_id:sample_id + 1], patch_size, s)
                    data_result[sample_id] = d_crop
                    if seg is not None:
                        seg_result[sample_id] = s_crop

            # Check only this sample's patch, not the whole batch.
            # Cancer could be very small, so 'fore' uses an approximate check.
            has_data = bool(np.any(data_result[sample_id] != 0))
            if targeted:
                accepted = has_data and bool(
                    np.any(_target_mask(seg_result[sample_id])))
            else:
                accepted = has_data
            n += 1

            if n > 5:
                if cand_point_coord is not None:
                    cand_value = seg[sample_id, 0][tuple(cand_point_coord)]
                else:
                    cand_value = None
                unique_seg = (None if seg_result is None else
                              np.unique(seg_result, return_counts=True))
                logger.info(
                    'tag:{}, patch_type: {}; handler: {}; times: {}; cand point:{}; cand point seg value:{}; ctr_list:{}; data.shape:{}; np.unique(seg_result):{}; np.sum(data_result):{}'
                    .format(tag, patch_type, int(accepted), n,
                            str(cand_point_coord), cand_value,
                            str(ctr_list), str(data.shape), unique_seg,
                            np.sum(data_result)))

        augs.extend(attempt_augs)

    return data_result, seg_result, augs
def affine_transformation(vol, radius, translate, scale, bspline_order,
                          border_mode, constant_val, is_reverse):
    """Resample every (sample, channel) volume on a scaled/rotated/shifted grid.

    The sampling grid is centred on the volume centre, transformed, and
    shifted back before interpolation.

    * forward (``is_reverse`` False): scale -> rotation (x, y, z) -> translation
    * reverse (``is_reverse`` True):  translation -> rotation (z, y, x) -> scale

    Reverse mode only changes the order of the steps. It uses the given
    angles, offsets and scale factors as-is and does not negate or invert them.

    :param vol: np.ndarray or torch.Tensor laid out as (batch, channel, *spatial).
        Rotation builds 3-D rotation matrices, so three spatial axes are
        expected.
    :param radius: rotation angles for the x, y and z axes (indexed 0, 1, 2).
    :param translate: per-axis offsets subtracted from the sampling grid.
    :param scale: scalar (all axes) or per-axis sequence multiplied into the grid.
    :param bspline_order: interpolation order, 0~5. 0: nearest, 1: bilinear,
        2~5: B-spline (3 is commonly used).
    :param border_mode: boundary mode, e.g. "constant" or "nearest".
    :param constant_val: constant value used by the "constant" border mode.
    :param is_reverse: bool, select the reverse step order.
    :return: float32 array shaped like ``vol``; a torch tensor on the input's
        device if ``vol`` was a tensor.
    """
    is_tensor = isinstance(vol, torch.Tensor)
    device = torch.device("cpu")
    if is_tensor:
        device = vol.device
        vol = vol.cpu().detach().numpy()

    spatial_shape = vol.shape[2:]
    n_dims = len(vol.shape) - 2
    mid = tuple((extent - 1) / 2. for extent in spatial_shape)
    axes = tuple([np.arange(extent) for extent in spatial_shape])
    grid = np.array(np.meshgrid(*axes, indexing='ij')).astype(float)
    for axis, offset in enumerate(mid):
        grid[axis] -= offset  # grid is now centred on the volume centre

    def apply_scale(points):
        if isinstance(scale, (tuple, list, np.ndarray)):  # per-axis factors
            assert len(scale) == len(points)
            for axis, factor in enumerate(scale):
                points[axis] *= factor
        else:  # one factor for every axis
            points *= scale
        return points

    def apply_rotation(points, steps):
        matrix = np.identity(len(points))
        for make_rotation, angle_index in steps:
            matrix = make_rotation(radius[angle_index], matrix)
        flat = points.reshape(len(points), -1).transpose()
        return np.dot(flat, matrix).transpose().reshape(points.shape)

    def apply_translation(points):
        for axis in range(n_dims):
            points[axis] -= translate[axis]
        return points

    rotation_steps = [(create_matrix_rotation_x_3d, 0),
                      (create_matrix_rotation_y_3d, 1),
                      (create_matrix_rotation_z_3d, 2)]
    if is_reverse:
        grid = apply_translation(grid)
        grid = apply_rotation(grid, rotation_steps[::-1])
        grid = apply_scale(grid)
    else:
        grid = apply_scale(grid)
        grid = apply_rotation(grid, rotation_steps)
        grid = apply_translation(grid)

    for axis in range(n_dims):
        grid[axis] += mid[axis]  # back to absolute voxel coordinates

    result = np.zeros_like(vol, dtype=np.float32)
    for sample_id, channel_id in np.ndindex(*vol.shape[:2]):
        result[sample_id, channel_id] = interpolate_img(
            vol[sample_id, channel_id],
            grid,
            order=bspline_order,
            mode=border_mode,
            cval=constant_val)

    if is_tensor:
        return torch.from_numpy(result).to(device)
    return result
Example #10
0
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30,
                    do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                    do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                    do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True):
    """Apply every enabled spatial augmentation and resample one patch per sample.

    Each enabled transform (elastic deformation, rotation, scaling) is always
    applied. There are no per-sample probabilities. The transformed grid is
    then shifted to a patch centre and interpolated.

    :param data: array indexed as [sample, channel, *spatial].
    :param seg: array indexed like ``data``, or None.
    :param patch_size: spatial patch size (2 or 3 entries).
    :param patch_center_dist_from_border: scalar or per-axis minimum distance
        of a random patch centre from the border.
    :param random_crop: uniform random centre if True, image centre otherwise.
    :return: (data_result, seg_result) float32 arrays of shape
        [n, c, *patch_size]; seg_result is None when seg is None.
    """
    dim = len(patch_size)
    if dim == 2:
        out_spatial = (patch_size[0], patch_size[1])
    else:
        out_spatial = (patch_size[0], patch_size[1], patch_size[2])

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + out_spatial, dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + out_spatial, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def draw_angle(bounds):
        # A degenerate range means a fixed angle; no random draw is made.
        if bounds[0] == bounds[1]:
            return bounds[0]
        return np.random.uniform(bounds[0], bounds[1])

    for sample_id in range(data.shape[0]):
        grid = create_zero_centered_coordinate_mesh(patch_size)

        if do_elastic_deform:
            magnitude = np.random.uniform(alpha[0], alpha[1])
            smoothness = np.random.uniform(sigma[0], sigma[1])
            grid = elastic_deform_coordinates(grid, magnitude, smoothness)

        if do_rotation:
            a_x = draw_angle(angle_x)
            if dim == 3:
                a_y = draw_angle(angle_y)
                a_z = draw_angle(angle_z)
                grid = rotate_coords_3d(grid, a_x, a_y, a_z)
            else:
                grid = rotate_coords_2d(grid, a_x)

        if do_scale:
            shrink = np.random.random() < 0.5 and scale[0] < 1
            low, high = (scale[0], 1) if shrink else (max(scale[0], 1), scale[1])
            grid = scale_coords(grid, np.random.uniform(low, high))

        # Shift the zero-centred grid onto the chosen patch centre.
        for axis in range(dim):
            extent = data.shape[axis + 2]
            margin = patch_center_dist_from_border[axis]
            if random_crop:
                offset = np.random.uniform(margin, extent - margin)
            else:
                offset = int(np.round(extent / 2.))
            grid[axis] += offset

        for channel_id, channel in enumerate(data[sample_id]):
            data_result[sample_id, channel_id] = interpolate_img(channel, grid, order_data,
                                                                 border_mode_data, cval=border_cval_data)
        if seg is not None:
            for channel_id, channel in enumerate(seg[sample_id]):
                seg_result[sample_id, channel_id] = interpolate_img(channel, grid, order_seg,
                                                                    border_mode_seg, cval=border_cval_seg,
                                                                    is_seg=True)
    return data_result, seg_result
def augment_spatial(data,
                    seg,
                    patch_size,
                    patch_center_dist_from_border=30,
                    do_elastic_deform=True,
                    alpha=(0., 1000.),
                    sigma=(10., 13.),
                    do_rotation=True,
                    angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi),
                    angle_z=(0, 2 * np.pi),
                    do_scale=True,
                    scale=(0.75, 1.25),
                    border_mode_data='nearest',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=1,
                    p_scale_per_sample=1,
                    p_rot_per_sample=1,
                    independent_scale_for_each_axis=False,
                    p_rot_per_axis: float = 1):
    """Spatially augment a batch and extract one patch per sample.

    Per sample, elastic deformation, rotation and scaling are each applied
    with their own probability. If any was applied, the patch is resampled
    from the transformed grid with ``interpolate_img``. Otherwise it is cut
    out directly with ``random_crop_aug`` / ``center_crop_aug``, padding with
    the ``np.pad`` mode equivalent to the requested border mode.

    :param data: array indexed as [sample, channel, *spatial].
    :param seg: array indexed like ``data``, or None.
    :param patch_size: spatial patch size (2 or 3 entries).
    :param patch_center_dist_from_border: scalar or per-axis minimum distance
        of a random patch centre from the border.
    :param border_mode_data: boundary mode for resampling ``data``; also
        translated to an ``np.pad`` mode for the crop-only path.
    :param border_mode_seg: same as ``border_mode_data``, for ``seg``.
    :param random_crop: random patch centre if True, image centre otherwise.
    :param independent_scale_for_each_axis: draw one scale factor per axis.
    :param p_rot_per_axis: probability that each axis gets a random angle
        (otherwise 0).
    :return: (data_result, seg_result) float32 arrays of shape
        [n, c, *patch_size]; seg_result is None when seg is None.
    """
    # Border modes are assumed to use scipy.ndimage names (as interpolate_img
    # receives them); np.pad spells the equivalent modes differently.
    scipy_to_pad_mode = {
        'nearest': 'edge',
        'reflect': 'symmetric',
        'grid-mirror': 'symmetric',
        'mirror': 'reflect',
        'wrap': 'wrap',
        'grid-wrap': 'wrap',
    }

    def _to_pad_args(mode, cval):
        """Return (np.pad mode, np.pad kwargs) for a resampling border mode."""
        if mode in ('constant', 'grid-constant'):
            return 'constant', {'constant_values': cval}
        # Unknown names pass through unchanged; np.pad accepts or rejects them.
        return scipy_to_pad_mode.get(mode, mode), {}

    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        if dim == 2:
            seg_result = np.zeros(
                (seg.shape[0], seg.shape[1], patch_size[0], patch_size[1]),
                dtype=np.float32)
        else:
            seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0],
                                   patch_size[1], patch_size[2]),
                                  dtype=np.float32)

    if dim == 2:
        data_result = np.zeros(
            (data.shape[0], data.shape[1], patch_size[0], patch_size[1]),
            dtype=np.float32)
    else:
        data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0],
                                patch_size[1], patch_size[2]),
                               dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    # Loop-invariant: padding settings for the crop-only path.
    border_pad_mode_data, pad_kwargs_data = _to_pad_args(border_mode_data,
                                                         border_cval_data)
    border_pad_mode_seg, pad_kwargs_seg = _to_pad_args(border_mode_seg,
                                                       border_cval_seg)

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:

            if np.random.uniform() <= p_rot_per_axis:
                a_x = np.random.uniform(angle_x[0], angle_x[1])
            else:
                a_x = 0

            if dim == 3:
                if np.random.uniform() <= p_rot_per_axis:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])
                else:
                    a_y = 0

                if np.random.uniform() <= p_rot_per_axis:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])
                else:
                    a_z = 0

                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if not independent_scale_for_each_axis:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
            else:
                sc = []
                for _ in range(dim):
                    if np.random.random() < 0.5 and scale[0] < 1:
                        sc.append(np.random.uniform(scale[0], 1))
                    else:
                        sc.append(np.random.uniform(max(scale[0], 1),
                                                    scale[1]))
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Shift the zero-centred grid onto the chosen patch centre.
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(
                        patch_center_dist_from_border[d],
                        data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id],
                    coords,
                    order_data,
                    border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id],
                        coords,
                        order_seg,
                        border_mode_seg,
                        cval=border_cval_seg,
                        is_seg=True)
        else:
            # No spatial transform drawn: crop directly (no interpolation).
            if seg is None:
                s = None
            else:
                s = seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [
                    patch_center_dist_from_border[d] - patch_size[d] // 2
                    for d in range(dim)
                ]
                d, s = random_crop_aug(data[sample_id:sample_id + 1],
                                       s,
                                       patch_size,
                                       margin,
                                       pad_mode=border_pad_mode_data,
                                       pad_kwargs=pad_kwargs_data,
                                       pad_mode_seg=border_pad_mode_seg,
                                       pad_kwargs_seg=pad_kwargs_seg)
            else:
                d, s = center_crop_aug(data[sample_id:sample_id + 1],
                                       patch_size,
                                       s,
                                       pad_mode=border_pad_mode_data,
                                       pad_kwargs=pad_kwargs_data,
                                       pad_mode_seg=border_pad_mode_seg,
                                       pad_kwargs_seg=pad_kwargs_seg)
            data_result[sample_id] = d[0]
            if seg is not None:
                seg_result[sample_id] = s[0]
    return data_result, seg_result
def augment_spatial(data,
                    seg,
                    patch_size,
                    patch_center_dist_from_border=30,
                    do_elastic_deform=True,
                    alpha=(0., 1000.),
                    sigma=(10., 13.),
                    do_rotation=True,
                    angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi),
                    angle_z=(0, 2 * np.pi),
                    do_scale=True,
                    scale=(0.75, 1.25),
                    border_mode_data='nearest',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    crop_mode='random',
                    p_el_per_sample=1,
                    p_scale_per_sample=1,
                    p_rot_per_sample=1):
    """Spatially augment a batch and extract one patch per sample.

    Elastic deformation, rotation and scaling are each applied per sample with
    their own probability. If any was applied, the patch is resampled from
    the transformed grid; otherwise it is cropped directly.

    :param data: array indexed as [sample, channel, *spatial].
    :param seg: array indexed like ``data``, or None. Required for
        ``crop_mode='roi'`` (seg channel 0 determines the crop).
    :param patch_size: spatial patch size (2 or 3 entries).
    :param crop_mode: 'roi' (crop from ``get_lbls_for_roi_crop`` / ``crop``),
        'center' (image centre), anything else means random.
    :return: (data_result, seg_result) float32 arrays of shape
        [n, c, *patch_size]; seg_result is None when seg is None.
    """
    dim = len(patch_size)
    if dim == 2:
        out_spatial = (patch_size[0], patch_size[1])
    else:
        out_spatial = (patch_size[0], patch_size[1], patch_size[2])

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + out_spatial,
                              dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + out_spatial,
                           dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def draw_angle(bounds):
        # A degenerate range means a fixed angle; no random draw is made.
        if bounds[0] == bounds[1]:
            return bounds[0]
        return np.random.uniform(bounds[0], bounds[1])

    for sample_id in range(data.shape[0]):
        grid = create_zero_centered_coordinate_mesh(patch_size)
        transformed = False

        # The probability draw comes first in each test so the RNG stream is
        # consumed even when the transform is disabled.
        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            magnitude = np.random.uniform(alpha[0], alpha[1])
            smoothness = np.random.uniform(sigma[0], sigma[1])
            grid = elastic_deform_coordinates(grid, magnitude, smoothness)
            transformed = True

        if np.random.uniform() < p_rot_per_sample and do_rotation:
            a_x = draw_angle(angle_x)
            if dim == 3:
                a_y = draw_angle(angle_y)
                a_z = draw_angle(angle_z)
                grid = rotate_coords_3d(grid, a_x, a_y, a_z)
            else:
                grid = rotate_coords_2d(grid, a_x)
            transformed = True

        if np.random.uniform() < p_scale_per_sample and do_scale:
            shrink = np.random.random() < 0.5 and scale[0] < 1
            low, high = (scale[0], 1) if shrink else (max(scale[0], 1), scale[1])
            grid = scale_coords(grid, np.random.uniform(low, high))
            transformed = True

        if not transformed:
            # No spatial transform drawn: crop directly (no interpolation).
            seg_slice = None if seg is None else seg[sample_id:sample_id + 1]
            data_slice = data[sample_id:sample_id + 1]

            if crop_mode == 'roi':
                cropped_data, cropped_seg = crop(data_slice,
                                                 seg_slice,
                                                 patch_size,
                                                 crop_type='roi')
            elif crop_mode == 'center':
                cropped_data, cropped_seg = center_crop_aug(data_slice,
                                                            patch_size,
                                                            seg_slice)
            else:  # crop_mode == 'random'
                margin = [
                    patch_center_dist_from_border[axis] - patch_size[axis] // 2
                    for axis in range(dim)
                ]
                cropped_data, cropped_seg = random_crop_aug(data_slice,
                                                            seg_slice,
                                                            patch_size,
                                                            margin)

            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]
            continue

        # Shift the zero-centred grid onto the chosen patch centre.
        if crop_mode == 'roi':
            lbs = get_lbls_for_roi_crop(patch_size, data.shape,
                                        seg[sample_id, 0])
        for axis in range(dim):
            if crop_mode == 'roi':
                offset = int(lbs[axis] + np.round(patch_size[axis] / 2.))
            elif crop_mode == 'center':
                offset = int(np.round(data.shape[axis + 2] / 2.))
            else:
                offset = np.random.uniform(
                    patch_center_dist_from_border[axis],
                    data.shape[axis + 2] - patch_center_dist_from_border[axis])
            grid[axis] += offset

        for channel_id, channel in enumerate(data[sample_id]):
            data_result[sample_id, channel_id] = interpolate_img(
                channel,
                grid,
                order_data,
                border_mode_data,
                cval=border_cval_data)
        if seg is not None:
            for channel_id, channel in enumerate(seg[sample_id]):
                seg_result[sample_id, channel_id] = interpolate_img(
                    channel,
                    grid,
                    order_seg,
                    border_mode_seg,
                    cval=border_cval_seg,
                    is_seg=True)

    return data_result, seg_result
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30,
                    do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                    do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                    do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=1,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, fraction=0.0, spacing=(1,1,1)):
    """Randomly deform/rotate/scale a coordinate grid and resample one patch per sample.

    ``data`` (and ``seg`` when given) are prefixed with singleton axes until they
    have ``len(patch_size) + 2`` dimensions, i.e. (batch, channel, *spatial).

    Args:
        data: image array; resampled channel by channel with ``interpolate_img``.
        seg: optional segmentation array (may be None). Channels ``1:`` with
            values ``> 0`` are treated as foreground when choosing an
            object-centred patch.
        patch_size: spatial size of the output patch (2 or 3 entries).
        patch_center_dist_from_border: scalar or per-axis minimum distance of a
            randomly drawn patch centre from the volume border.
        do_elastic_deform, alpha, sigma: elastic deformation toggle and the
            ranges its magnitude/smoothness are drawn from. The magnitude is
            divided per axis by ``spacing``.
        do_rotation, angle_x, angle_y, angle_z: rotation toggle and angle
            ranges (radians); a degenerate range ``(a, a)`` means "always a".
        do_scale, scale: scaling toggle and range; values below and above 1
            are each chosen with probability 0.5 when ``scale[0] < 1``.
        border_mode_*, border_cval_*, order_*: interpolation settings passed
            to ``interpolate_img`` for data and seg respectively.
        fraction: probability that the patch is centred on the (jittered)
            foreground centroid instead of a random location.
        spacing: per-axis voxel spacing used to scale the elastic magnitude.

    Returns:
        (data_patch, seg_patch, rval): the augmented patch of the first sample,
        the matching seg patch (None when ``seg`` is None), and the last uniform
        draw used to decide between random and object-centred placement.
    """
    dim = len(patch_size)
    patch_size_arr = np.asarray(patch_size)
    n_axes = dim + 2  # batch + channel + spatial axes

    while len(data.shape) < n_axes:
        data = data[np.newaxis, :]
    if seg is not None:
        while len(seg.shape) < n_axes:
            seg = seg[np.newaxis, :]

    out_spatial = tuple(patch_size)
    data_result = np.zeros((data.shape[0], data.shape[1]) + out_spatial, dtype=np.float32)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + out_spatial, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    spatial_shape = np.asarray(data.shape[2:])

    def _draw(bounds):
        # A degenerate range returns its value without consuming a random draw.
        if bounds[0] == bounds[1]:
            return bounds[0]
        return np.random.uniform(bounds[0], bounds[1])

    rval = None
    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        if do_elastic_deform:
            # One magnitude/smoothness per axis; magnitude is rescaled by the
            # voxel spacing. NOTE(review): assumes elastic_deform_coordinates
            # expects one alpha/sigma entry per spatial axis.
            a_base = np.repeat(np.random.uniform(alpha[0], alpha[1]), dim)
            a = a_base / np.asarray(spacing, dtype=float)[:dim]
            s = np.repeat(np.random.uniform(sigma[0], sigma[1]), dim)
            coords = elastic_deform_coordinates(coords, a, s)
        if do_rotation:
            a_x = _draw(angle_x)
            if dim == 3:
                a_y = _draw(angle_y)
                a_z = _draw(angle_z)
                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
        if do_scale:
            if np.random.random() < 0.5 and scale[0] < 1:
                sc = np.random.uniform(scale[0], 1)
            else:
                sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)

        # Choose the patch centre. Both draws happen for every sample so the
        # random stream is identical regardless of which branch is taken.
        jitter = np.random.uniform(0.8, 1.2)
        rval = np.random.uniform(0.0, 1.0)
        ctr_obj = None
        if rval < fraction and seg is not None:
            # Drop the channel index; keep one index array per spatial axis.
            fg_idx = np.nonzero(seg[sample_id, 1:] > 0)[1:]
            if fg_idx[0].size > 0:
                ctr_obj = jitter * np.array([idx.mean() for idx in fg_idx])
                ctr_obj = np.clip(ctr_obj, patch_size_arr, spatial_shape - patch_size_arr)
            # No foreground (or no seg): fall back to a random centre rather
            # than propagating a NaN centroid.

        for d in range(dim):
            if ctr_obj is None:
                ctr = np.random.uniform(patch_center_dist_from_border[d],
                                        data.shape[d + 2] - patch_center_dist_from_border[d])
            else:
                ctr = ctr_obj[d]
            coords[d] += ctr

        for channel_id in range(data.shape[1]):
            data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                                 border_mode_data, cval=border_cval_data)
        if seg is not None:
            for channel_id in range(seg.shape[1]):
                seg_result[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                                    border_mode_seg, cval=border_cval_seg)

    seg_patch = None if seg_result is None else seg_result[0]
    return data_result[0], seg_patch, rval