def augment_spatial_peaks(data, seg, patch_size, patch_center_dist_from_border=30,
                          do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                          do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi),
                          angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25),
                          border_mode_data='nearest', border_cval_data=0, order_data=3,
                          border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                          random_crop=True, p_el_per_sample=1, p_scale_per_sample=1,
                          p_rot_per_sample=1, slice_dir=None):
    """Spatially augment a batch of 2D peak/tensor slices, rotating the vectors with the grid.

    The sampling grid of each sample is optionally elastically deformed,
    rotated in-plane and scaled. If any of these was applied, the sample is
    resampled with ``interpolate_img``; otherwise it is simply cropped.
    Afterwards the channel vectors of the sample are rotated by the same
    in-plane angle. ``rotate_multiple_peaks`` is used for 9 channels and
    ``rotate_multiple_tensors`` for 18, so the directions follow the rotated
    voxels.

    Args:
        data: batch array indexed as ``data[sample, channel, *spatial]``. The
            channel count must be 9 or 18.
        seg: optional batch array indexed like ``data``, or None.
        patch_size: spatial output size. Only 2D (length 2) is supported.
        patch_center_dist_from_border: scalar or per-axis sequence. When
            ``random_crop`` is True, the patch centre is drawn uniformly from
            ``[dist, size - dist]`` along each axis.
        do_elastic_deform, alpha, sigma: enable elastic deformation. ``alpha``
            and ``sigma`` are ranges; one value is drawn uniformly from each
            and passed to ``elastic_deform_coordinates``.
        do_rotation, angle_x: enable rotation. The in-plane angle (radians) is
            drawn from ``angle_x``, or used as-is when both bounds are equal.
        angle_y, angle_z: kept for interface compatibility; unused because
            only 2D is supported.
        do_scale, scale: enable scaling; the factor is drawn from ``scale``.
        border_mode_data, border_cval_data, order_data: interpolation settings
            for ``data``.
        border_mode_seg, border_cval_seg, order_seg: interpolation settings
            for ``seg``.
        random_crop: place the patch centre randomly (True) or at the image
            centre (False).
        p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
            probabilities of applying each transform.
        slice_dir: 0, 1 or 2. Selects which 3D rotation angle (x, y, or the
            negated z angle) receives the sampled in-plane angle when
            rotating the peaks/tensors.

    Returns:
        Tuple ``(data_result, seg_result)`` of float32 arrays of shape
        ``(batch, channels, *patch_size)``. ``seg_result`` is None if ``seg``
        is None.

    Raises:
        ValueError: if ``patch_size`` is not 2D, ``slice_dir`` is not 0/1/2,
            or ``data`` does not have 9 or 18 channels. These are checked once
            before any sample is processed.
    """
    dim = len(patch_size)

    # Validate up front instead of after the first sample was already resampled.
    if dim != 2:
        raise ValueError("augment_spatial_peaks only supports 2D at the moment")
    if slice_dir not in (0, 1, 2):
        raise ValueError("invalid slice_dir passed as argument")
    if data.shape[1] not in (9, 18):
        raise ValueError("Incorrect number of channels (expected 9 or 18)")

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size), dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1], *patch_size), dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    # The channel count was validated above, so the vector rotation is fixed per call.
    rotate_vectors = rotate_multiple_peaks if data.shape[1] == 9 else rotate_multiple_tensors

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        # In-plane angle. It stays 0 when no rotation is sampled, so the
        # vector rotation below is then a zero-angle rotation.
        sampled_2d_angle = 0
        if np.random.uniform() < p_rot_per_sample and do_rotation:
            if angle_x[0] == angle_x[1]:
                sampled_2d_angle = angle_x[0]
            else:
                sampled_2d_angle = np.random.uniform(angle_x[0], angle_x[1])
            coords = rotate_coords_2d(coords, sampled_2d_angle)
            modified_coords = True

        if np.random.uniform() < p_scale_per_sample and do_scale:
            if np.random.random() < 0.5 and scale[0] < 1:
                sc = np.random.uniform(scale[0], 1)
            else:
                sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Move the zero-centred grid to the chosen patch centre.
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    # Centre voxel index of an axis of length n is (n - 1) / 2.
                    # The previous int(np.round(n / 2.)) was half a voxel off for
                    # even n and, because np.round rounds half to even, a full voxel
                    # off for some odd n (n=11 -> 6 instead of 5).
                    # NOTE(review): assumes create_zero_centered_coordinate_mesh
                    # spans -(p-1)/2 .. (p-1)/2 -- confirm.
                    ctr = data.shape[d + 2] / 2. - 0.5
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id], coords, order_data, border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id], coords, order_seg, border_mode_seg,
                        cval=border_cval_seg, is_seg=True)
        else:
            # Grid untouched: a plain crop is enough.
            seg_sample = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2
                          for d in range(dim)]
                cropped_data, cropped_seg = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_sample, patch_size, margin)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_sample)
            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]

        # Rotate the peak/tensor vectors by the same angle as the voxel grid,
        # about the 3D axis selected by slice_dir.
        a_x = a_y = a_z = 0
        if slice_dir == 0:
            a_x = sampled_2d_angle
        elif slice_dir == 1:
            a_y = sampled_2d_angle
        else:  # slice_dir == 2 (validated above)
            # Somehow we have to invert rotation direction for z to make align properly with
            # rotated voxels. Unclear why this is the case. Maybe some different conventions
            # for peaks and voxels??
            a_z = -sampled_2d_angle

        data_result[sample_id] = rotate_vectors(data_result[sample_id], a_x, a_y, a_z)

    return data_result, seg_result
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30,
                    do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                    do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi),
                    angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25),
                    border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                    random_crop=True, p_el_per_sample=1, p_scale_per_sample=1,
                    p_rot_per_sample=1, independent_scale_for_each_axis=False,
                    p_rot_per_axis: float = 1, p_independent_scale_per_axis: float = 1):
    """Spatially augment a batch with elastic deformation, rotation, scaling and cropping.

    For each sample, a zero-centred sampling grid of ``patch_size`` is
    optionally deformed, rotated and scaled, then moved to a patch centre.
    If any transform was applied, the sample is resampled with
    ``interpolate_img``; otherwise it is cropped with ``random_crop_aug`` or
    ``center_crop_aug``.

    Args:
        data: batch array indexed as ``data[sample, channel, *spatial]``.
        seg: optional batch array indexed like ``data``, or None.
        patch_size: spatial output size (2D or 3D).
        patch_center_dist_from_border: scalar or per-axis sequence. With
            ``random_crop`` the centre is drawn from ``[dist, size - dist]``.
        do_elastic_deform, alpha, sigma: enable elastic deformation. One value
            is drawn uniformly from each range and passed to
            ``elastic_deform_coordinates``.
        do_rotation, angle_x, angle_y, angle_z: enable rotation and give the
            angle ranges in radians. ``angle_y``/``angle_z`` are only used in 3D.
        do_scale, scale: enable scaling; factors are drawn from ``scale``.
        border_mode_data, border_cval_data, order_data: interpolation settings
            for ``data``.
        border_mode_seg, border_cval_seg, order_seg: interpolation settings
            for ``seg``.
        random_crop: random (True) or central (False) patch placement.
        p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
            probabilities of applying each transform.
        independent_scale_for_each_axis: allow a separate scale per axis.
        p_rot_per_axis: probability that each individual axis is rotated.
        p_independent_scale_per_axis: probability of using per-axis scales
            when ``independent_scale_for_each_axis`` is True.

    Returns:
        Tuple ``(data_result, seg_result)`` of float32 arrays of shape
        ``(batch, channels, *patch_size)``. ``seg_result`` is None if ``seg``
        is None.
    """
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size), dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1], *patch_size), dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def draw_angle(angle_range):
        # Each axis is rotated only with probability p_rot_per_axis.
        if np.random.uniform() <= p_rot_per_axis:
            return np.random.uniform(angle_range[0], angle_range[1])
        return 0

    def draw_scale():
        # Shrink (< 1) or enlarge (>= 1) with 50 % chance each, if shrinking is allowed.
        if np.random.random() < 0.5 and scale[0] < 1:
            return np.random.uniform(scale[0], 1)
        return np.random.uniform(max(scale[0], 1), scale[1])

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            a_x = draw_angle(angle_x)
            if dim == 3:
                a_y = draw_angle(angle_y)
                a_z = draw_angle(angle_z)
                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if independent_scale_for_each_axis and \
                    np.random.uniform() < p_independent_scale_per_axis:
                sc = [draw_scale() for _ in range(dim)]
            else:
                sc = draw_scale()
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    # Centre voxel index of an axis of length n is (n - 1) / 2.
                    # int(np.round(n / 2.)) was half a voxel off for even n and a full
                    # voxel off for some odd n (round-half-to-even: n=11 -> 6, not 5).
                    # NOTE(review): assumes create_zero_centered_coordinate_mesh
                    # spans -(p-1)/2 .. (p-1)/2 -- confirm.
                    ctr = data.shape[d + 2] / 2. - 0.5
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id], coords, order_data, border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id], coords, order_seg, border_mode_seg,
                        cval=border_cval_seg, is_seg=True)
        else:
            seg_sample = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2
                          for d in range(dim)]
                cropped_data, cropped_seg = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_sample, patch_size, margin)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_sample)
            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]

    return data_result, seg_result
def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30,
                      do_elastic_deform=True, deformation_scale=(0, 0.25),
                      do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi),
                      angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25),
                      border_mode_data='nearest', border_cval_data=0, order_data=3,
                      border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                      random_crop=True, p_el_per_sample=1, p_scale_per_sample=1,
                      p_rot_per_sample=1, independent_scale_for_each_axis=False,
                      p_rot_per_axis: float = 1, p_independent_scale_per_axis: float = 1):
    """Spatially augment a batch using a patch-relative elastic deformation.

    Behaves like ``augment_spatial`` except for the elastic deformation. One
    relative deformation scale is drawn per sample from ``deformation_scale``
    and converted to a per-axis sigma in voxels (``scale * patch_size[d]``).
    For each axis, a magnitude is drawn uniformly between sigma/8 and sigma/2.
    Both lists are passed to ``elastic_deform_coordinates_2``. Before the
    patch centre is added, the transformed grid is re-centred on its mean.

    Args:
        data: batch array indexed as ``data[sample, channel, *spatial]``.
        seg: optional batch array indexed like ``data``, or None.
        patch_size: spatial output size (2D or 3D).
        patch_center_dist_from_border: scalar or per-axis sequence. With
            ``random_crop`` the centre is drawn from ``[dist, size - dist]``.
        do_elastic_deform: enable the elastic deformation.
        deformation_scale: range of the deformation scale, as a fraction of
            ``patch_size`` (e.g. 0.125 = 12.5 % of the patch size per axis).
            Small values give local deformations, large values global ones.
        do_rotation, angle_x, angle_y, angle_z: enable rotation and give the
            angle ranges in radians. ``angle_y``/``angle_z`` are only used in 3D.
        do_scale, scale: enable scaling; factors are drawn from ``scale``.
        border_mode_data, border_cval_data, order_data: interpolation settings
            for ``data``.
        border_mode_seg, border_cval_seg, order_seg: interpolation settings
            for ``seg``.
        random_crop: random (True) or central (False) patch placement.
        p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
            probabilities of applying each transform.
        independent_scale_for_each_axis: allow a separate scale per axis.
        p_rot_per_axis: probability that each individual axis is rotated.
        p_independent_scale_per_axis: probability of using per-axis scales
            when ``independent_scale_for_each_axis`` is True.

    Returns:
        Tuple ``(data_result, seg_result)`` of float32 arrays of shape
        ``(batch, channels, *patch_size)``. ``seg_result`` is None if ``seg``
        is None.
    """
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size), dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1], *patch_size), dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def draw_angle(angle_range):
        # Each axis is rotated only with probability p_rot_per_axis.
        if np.random.uniform() <= p_rot_per_axis:
            return np.random.uniform(angle_range[0], angle_range[1])
        return 0

    def draw_scale():
        # Shrink (< 1) or enlarge (>= 1) with 50 % chance each, if shrinking is allowed.
        if np.random.random() < 0.5 and scale[0] < 1:
            return np.random.uniform(scale[0], 1)
        return np.random.uniform(max(scale[0], 1), scale[1])

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            # One relative scale per sample, expressed as a fraction of patch_size.
            def_scale = np.random.uniform(deformation_scale[0], deformation_scale[1])
            sigmas = []
            mag = []
            for d in range(dim):
                sigma_d = def_scale * patch_size[d]  # relative scale -> voxels
                sigmas.append(sigma_d)
                # The magnitude is tied to sigma: too small and nothing visible happens,
                # too large and the deformation becomes unrealistic.
                mag.append(np.random.uniform(sigma_d * 0.125, sigma_d * 0.5))
            coords = elastic_deform_coordinates_2(coords, sigmas, mag)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            a_x = draw_angle(angle_x)
            if dim == 3:
                a_y = draw_angle(angle_y)
                a_z = draw_angle(angle_z)
                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if independent_scale_for_each_axis and \
                    np.random.uniform() < p_independent_scale_per_axis:
                sc = [draw_scale() for _ in range(dim)]
            else:
                sc = draw_scale()
            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Re-centre the (possibly deformed) grid on zero.
            coords -= coords.mean(axis=tuple(range(1, coords.ndim)), keepdims=True)
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    # The grid is zero-centred and the centre voxel index of an axis of
                    # length n is (n - 1) / 2. int(np.round(n / 2.)) was half a voxel
                    # off for even n and a full voxel off for some odd n
                    # (round-half-to-even: n=11 -> 6, not 5).
                    ctr = data.shape[d + 2] / 2. - 0.5
                coords[d] += ctr
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id], coords, order_data, border_mode_data,
                    cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id], coords, order_seg, border_mode_seg,
                        cval=border_cval_seg, is_seg=True)
        else:
            seg_sample = None if seg is None else seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2
                          for d in range(dim)]
                cropped_data, cropped_seg = random_crop_aug(
                    data[sample_id:sample_id + 1], seg_sample, patch_size, margin)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, seg_sample)
            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]

    return data_result, seg_result
def augment_spatial_transform(data, seg=None, patch_size=None, patch_center_dist_from_border=None,
                              do_elastic_deform=True, alpha=(0., 900.), sigma=(9., 13.),
                              do_rotation=True,
                              angle_x=default_3D_augmentation_params.get("rotation_x"),
                              angle_y=default_3D_augmentation_params.get("rotation_y"),
                              angle_z=default_3D_augmentation_params.get("rotation_z"),
                              do_scale=True, scale=(0.85, 1.25), border_mode_data='constant',
                              border_cval_data=0, order_data=3, border_mode_seg="constant",
                              border_cval_seg=-1, order_seg=1, random_crop=False,
                              p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1):
    """Spatially augment a single (non-batched) sample.

    Single-sample counterpart of ``augment_spatial``: ``data`` is indexed as
    ``data[channel, *spatial]``. A 3-dimensional ``data`` (or ``seg``) gets a
    leading channel axis added. The sampling grid is optionally deformed,
    rotated and scaled, then either resampled with ``interpolate_img`` or,
    if untouched, cropped.

    Args:
        data: array ``(C, D, H, W)``, or ``(D, H, W)``, which is promoted.
        seg: optional array shaped like ``data``, or None.
        patch_size: spatial output size; defaults to the full spatial shape
            of ``data``.
        patch_center_dist_from_border: scalar or per-axis sequence. With
            ``random_crop`` the centre is drawn from ``[dist, size - dist]``.
            Required when ``random_crop`` is True.
        do_elastic_deform, alpha, sigma: enable elastic deformation. One value
            is drawn uniformly from each range and passed to
            ``elastic_deform_coordinates``.
        do_rotation, angle_x, angle_y, angle_z: enable rotation and give the
            angle ranges. A range with equal bounds is used as a fixed angle.
        do_scale, scale: enable scaling; the factor is drawn from ``scale``.
        border_mode_data, border_cval_data, order_data: interpolation settings
            for ``data``.
        border_mode_seg, border_cval_seg, order_seg: interpolation settings
            for ``seg``.
        random_crop: random (True) or central (False) patch placement.
        p_el_per_sample, p_scale_per_sample, p_rot_per_sample: probabilities
            of applying each transform.

    Returns:
        Tuple ``(data_result, seg_result)`` of shape ``(C, *patch_size)``.
        The arrays are float32 when resampled; on the crop path their dtype
        comes from the crop helpers. ``seg_result`` is None if ``seg`` is None.

    Raises:
        ValueError: if ``random_crop`` is True but
            ``patch_center_dist_from_border`` is None.
    """
    if len(data.shape) == 3:
        print("Attention: data shape must be (C, D, H, W). I will transform it.")
        data = data[None]
    if seg is not None and len(seg.shape) == 3:
        seg = seg[None]
    if patch_size is None:
        patch_size = tuple(data.shape[1:])
    if random_crop and patch_center_dist_from_border is None:
        # Previously this failed later with an opaque TypeError (None arithmetic).
        raise ValueError("patch_center_dist_from_border must be given when random_crop=True")

    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], *patch_size), dtype=np.float32)
    data_result = np.zeros((data.shape[0], *patch_size), dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    def draw_angle(angle_range):
        # A degenerate range is used as a fixed angle without drawing a random number.
        if angle_range[0] == angle_range[1]:
            return angle_range[0]
        return np.random.uniform(angle_range[0], angle_range[1])

    coords = create_zero_centered_coordinate_mesh(patch_size)
    modified_coords = False

    if np.random.uniform() < p_el_per_sample and do_elastic_deform:
        a = np.random.uniform(alpha[0], alpha[1])
        s = np.random.uniform(sigma[0], sigma[1])
        coords = elastic_deform_coordinates(coords, a, s)
        modified_coords = True

    if np.random.uniform() < p_rot_per_sample and do_rotation:
        a_x = draw_angle(angle_x)
        if dim == 3:
            a_y = draw_angle(angle_y)
            a_z = draw_angle(angle_z)
            coords = rotate_coords_3d(coords, a_x, a_y, a_z)
        else:
            coords = rotate_coords_2d(coords, a_x)
        modified_coords = True

    if np.random.uniform() < p_scale_per_sample and do_scale:
        if np.random.random() < 0.5 and scale[0] < 1:
            sc = np.random.uniform(scale[0], 1)
        else:
            sc = np.random.uniform(max(scale[0], 1), scale[1])
        coords = scale_coords(coords, sc)
        modified_coords = True

    if modified_coords:
        for d in range(dim):
            if random_crop:
                ctr = np.random.uniform(patch_center_dist_from_border[d],
                                        data.shape[d + 1] - patch_center_dist_from_border[d])
            else:
                # Centre voxel index of an axis of length n is (n - 1) / 2.
                # int(np.around(n / 2.)) was half a voxel off for even n and a full
                # voxel off for some odd n (round-half-to-even: n=11 -> 6, not 5).
                # NOTE(review): assumes create_zero_centered_coordinate_mesh
                # spans -(p-1)/2 .. (p-1)/2 -- confirm.
                ctr = data.shape[d + 1] / 2. - 0.5
            coords[d] += ctr
        for channel_id in range(data.shape[0]):
            data_result[channel_id] = interpolate_img(data[channel_id], coords, order_data,
                                                      border_mode_data, cval=border_cval_data)
        if seg is not None:
            for channel_id in range(seg.shape[0]):
                seg_result[channel_id] = interpolate_img(seg[channel_id], coords, order_seg,
                                                         border_mode_seg, cval=border_cval_seg,
                                                         is_seg=True)
    else:
        # The crop helpers work on batches, so add and later strip a batch axis.
        seg_batch = None if seg is None else seg[None]
        if random_crop:
            margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
            cropped_data, cropped_seg = random_crop_aug(data[None], seg_batch, patch_size, margin)
        else:
            cropped_data, cropped_seg = center_crop_aug(data[None], patch_size, seg_batch)
        data_result = cropped_data[0]
        if seg is not None:
            seg_result = cropped_seg[0]

    return data_result, seg_result