def rotate(self, data_dict):
    """Rotate every sample of a batch in place using pre-drawn angles.

    :param data_dict: mapping with a ``"data"`` array indexed as
        (sample, channel, *spatial) and, optionally, a ``"seg"`` array that is
        indexed the same way.
    :return: dict with keys ``'data'`` and ``'seg'``; the arrays are the input
        arrays, overwritten sample by sample (``'seg'`` is ``None`` when the
        input had no ``"seg"`` entry).

    The per-sample angles are read from ``self.rand_params['ax' | 'ay' | 'az']``
    and the interpolation settings from ``self.params``.
    """
    data = data_dict["data"]
    has_seg = "seg" in data_dict
    seg = data_dict["seg"] if has_seg else None

    spatial_shape = np.array(data.shape[2:])
    is_3d = len(spatial_shape) == 3

    for b in range(data.shape[0]):
        # Build a rotated sampling grid for this sample, then shift it back
        # from zero-centred coordinates with uncenter_coords.
        grid = create_zero_centered_coordinate_mesh(spatial_shape)
        if is_3d:
            grid = rotate_coords_3d(grid,
                                    self.rand_params['ax'][b],
                                    self.rand_params['ay'][b],
                                    self.rand_params['az'][b])
        else:
            grid = rotate_coords_2d(grid, self.rand_params['ax'][b])
        grid = uncenter_coords(grid)

        for c in range(data.shape[1]):
            data[b, c] = interpolate_img(data[b, c], grid,
                                         self.params['order_data'],
                                         self.params['bmode_data'],
                                         cval=self.params['bcval_data'])

        if not has_seg:
            continue
        for c in range(seg.shape[1]):
            seg[b, c] = interpolate_img(seg[b, c], grid,
                                        self.params['order_seg'],
                                        self.params['bmode_seg'],
                                        cval=self.params['bcval_seg'],
                                        is_seg=True)

    return {'data': data, 'seg': seg}
def batch_deform_3d(vol, coords, bspline_order, border_mode, constant_val):
    """Resample every (sample, channel) volume of a batch at ``coords``.

    :param vol: array indexed as (sample, channel, ...); each ``vol[b, c]`` is
        handed to ``interpolate_img``.
    :param coords: sampling coordinates shared by all samples and channels.
    :param bspline_order: forwarded to ``interpolate_img`` as ``order``.
    :param border_mode: forwarded to ``interpolate_img`` as ``mode``.
    :param constant_val: forwarded to ``interpolate_img`` as ``cval``.
    :return: float32 array with the same shape as ``vol``.
    """
    interp_kwargs = dict(order=bspline_order, mode=border_mode, cval=constant_val)
    warped = np.zeros_like(vol, dtype=np.float32)
    for b in range(vol.shape[0]):
        for c in range(vol.shape[1]):
            warped[b, c] = interpolate_img(vol[b, c], coords, **interp_kwargs)
    return warped
def augment_spatial(self, data, coords, is_label=False):
    """Resample ``data`` at ``coords`` with the image or the label settings.

    :param data: array handed to ``interpolate_img``.
    :param coords: sampling coordinates handed to ``interpolate_img``.
    :param is_label: when true, use ``self.order_seg`` / ``self.border_mode_seg``
        / ``self.border_cval_seg``; otherwise use the ``*_data`` counterparts.
    :return: whatever ``interpolate_img`` returns for these arguments.
    """
    if is_label:
        settings = (self.order_seg, self.border_mode_seg, self.border_cval_seg)
    else:
        settings = (self.order_data, self.border_mode_data, self.border_cval_data)
    interp_order, interp_mode, interp_cval = settings
    return interpolate_img(data, coords, interp_order, interp_mode, cval=interp_cval)
def augment_spatial(data, seg, patch_size, labels_extra=None, patch_center_dist_from_border=30,
                    do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                    do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                    do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, p_el_per_sample=1,
                    p_scale_per_sample=1, p_rot_per_sample=1, independent_scale_for_each_axis=False,
                    p_rot_per_axis: float = 1):
    """Randomly deform, rotate, scale and crop a batch (plus optional extra labels).

    For every sample a zero-centred coordinate mesh of ``patch_size`` is built
    and optionally modified by elastic deformation, rotation and scaling. If
    any of those was applied, the patch is obtained by interpolating the
    sample at the shifted mesh; otherwise it is obtained by a plain random or
    centre crop.

    :param data: array indexed as (sample, channel, *spatial).
    :param seg: label array indexed like ``data``, or ``None``.
    :param patch_size: output spatial size (2 or 3 entries).
    :param labels_extra: optional list of additional label arrays, each indexed
        as (sample, channel, *spatial). They get the same coordinates and the
        segmentation interpolation settings as ``seg``. Each output buffer is
        sized from the corresponding array itself, so ``seg`` may be ``None``
        and channel counts may differ from ``seg``.
    :param patch_center_dist_from_border: scalar or per-axis sequence; with
        ``random_crop`` the patch centre is drawn uniformly from
        ``[dist, shape - dist]`` along each axis.
    :param do_elastic_deform: enable elastic deformation.
    :param alpha: (low, high) range for the first parameter passed to
        ``elastic_deform_coordinates``.
    :param sigma: (low, high) range for the second parameter passed to
        ``elastic_deform_coordinates``.
    :param do_rotation: enable rotation.
    :param angle_x: (low, high) range for the rotation angle (presumably
        radians, given the 0..2*pi defaults -- confirm against
        ``rotate_coords_*``).
    :param angle_y: as ``angle_x``; only used for 3D patches.
    :param angle_z: as ``angle_x``; only used for 3D patches.
    :param do_scale: enable scaling.
    :param scale: (low, high) range of scale factors. When ``scale[0] < 1``,
        half of the draws come from ``[scale[0], 1]`` and the rest from
        ``[max(scale[0], 1), scale[1]]``.
    :param border_mode_data: forwarded to ``interpolate_img`` for ``data``.
    :param border_cval_data: forwarded to ``interpolate_img`` for ``data``.
    :param order_data: forwarded to ``interpolate_img`` for ``data``.
    :param border_mode_seg: forwarded to ``interpolate_img`` for labels.
    :param border_cval_seg: forwarded to ``interpolate_img`` for labels.
    :param order_seg: forwarded to ``interpolate_img`` for labels.
    :param random_crop: random patch centre when true, volume centre otherwise.
    :param p_el_per_sample: per-sample probability of elastic deformation.
    :param p_scale_per_sample: per-sample probability of scaling.
    :param p_rot_per_sample: per-sample probability of rotation.
    :param independent_scale_for_each_axis: draw one scale factor per axis.
    :param p_rot_per_axis: probability that an individual axis gets a non-zero
        angle once rotation has been selected for the sample.
    :return: ``(data_result, seg_result, seg_extra_outputs)``; float32 arrays of
        shape (sample, channel, *patch_size). ``seg_result`` is ``None`` when
        ``seg`` is ``None``; ``seg_extra_outputs`` is a list (empty when
        ``labels_extra`` is ``None``).
    """
    dim = len(patch_size)
    spatial_shape = tuple(patch_size)

    def _alloc_output(batch):
        # One float32 output buffer per input array, sized from that array.
        return np.zeros((batch.shape[0], batch.shape[1]) + spatial_shape, dtype=np.float32)

    seg_result = None if seg is None else _alloc_output(seg)

    # shuailin: extra segs -- sized from each extra array itself, not from
    # `seg`, so that seg=None and differing channel counts are supported.
    seg_extra_outputs = []
    if labels_extra is not None:
        seg_extra_outputs = [_alloc_output(extra) for extra in labels_extra]

    data_result = _alloc_output(data)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            if np.random.uniform() <= p_rot_per_axis:
                a_x = np.random.uniform(angle_x[0], angle_x[1])
            else:
                a_x = 0

            if dim == 3:
                if np.random.uniform() <= p_rot_per_axis:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])
                else:
                    a_y = 0

                if np.random.uniform() <= p_rot_per_axis:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])
                else:
                    a_z = 0

                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if not independent_scale_for_each_axis:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
            else:
                sc = []
                for _ in range(dim):
                    if np.random.random() < 0.5 and scale[0] < 1:
                        sc.append(np.random.uniform(scale[0], 1))
                    else:
                        sc.append(np.random.uniform(max(scale[0], 1), scale[1]))
            coords = scale_coords(coords, sc)
            modified_coords = True

        # now find a nice center location
        if modified_coords:
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr

            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id], coords, order_data,
                    border_mode_data, cval=border_cval_data)

            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id], coords, order_seg,
                        border_mode_seg, cval=border_cval_seg, is_seg=True)

            # shuailin: extra segs transformation
            if labels_extra is not None:
                for ith, seg_arr in enumerate(labels_extra):
                    for channel_id in range(seg_arr.shape[1]):
                        seg_extra_outputs[ith][sample_id, channel_id] = interpolate_img(
                            seg_arr[sample_id, channel_id], coords, order_seg,
                            border_mode_seg, cval=border_cval_seg, is_seg=True)
        else:
            # No coordinate change was drawn: a plain crop is enough.
            if seg is None:
                s = None
            else:
                s = seg[sample_id:sample_id + 1]

            # shuailin extension
            if labels_extra is not None:
                labels_extra_sample = [extra[sample_id:sample_id + 1] for extra in labels_extra]
            else:
                labels_extra_sample = None

            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                d, s, seg_extra_output = random_crop_aug(
                    data[sample_id:sample_id + 1], s, patch_size, margin,
                    labels_extra=labels_extra_sample)
            else:
                d, s, seg_extra_output = center_crop_aug(
                    data[sample_id:sample_id + 1], patch_size, s,
                    labels_extra=labels_extra_sample)

            data_result[sample_id] = d[0]
            if seg is not None:
                seg_result[sample_id] = s[0]

            # shuailin extension
            if labels_extra is not None:
                for ith in range(len(labels_extra)):
                    seg_extra_outputs[ith][sample_id] = seg_extra_output[ith]

    return data_result, seg_result, seg_extra_outputs
def augment_spatial_peaks(data, seg, patch_size, patch_center_dist_from_border=30,
                          do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                          do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                          do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0,
                          order_data=3, border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True,
                          p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1, slice_dir=None):
    """Spatially augment a batch of peak / tensor images and rotate the peaks too.

    Each sample is deformed, rotated, scaled and cropped to ``patch_size``.
    Afterwards the channel content is rotated by the sampled in-plane angle
    via ``rotate_multiple_peaks`` (9 channels) or ``rotate_multiple_tensors``
    (18 channels).

    :param data: array indexed as (sample, channel, *spatial).
    :param seg: label array indexed like ``data``, or ``None``.
    :param patch_size: output spatial size; only 2 entries are supported once
        at least one sample is processed.
    :param patch_center_dist_from_border: scalar or per-axis sequence bounding
        the random patch centre.
    :param alpha: (low, high) range of the first elastic parameter.
    :param sigma: (low, high) range of the second elastic parameter.
    :param angle_x: (low, high) angle range; equal bounds give a fixed angle.
    :param angle_y: as ``angle_x``; only used for 3D patches.
    :param angle_z: as ``angle_x``; only used for 3D patches.
    :param scale: (low, high) range of the isotropic scale factor.
    :param slice_dir: 0, 1 or 2; selects which axis receives the sampled angle
        when the peaks are rotated (the sign is flipped for 2).
    :return: ``(data_result, seg_result)`` as float32 arrays; ``seg_result`` is
        ``None`` when ``seg`` is ``None``.
    :raises ValueError: for more than two spatial dimensions, an invalid
        ``slice_dir`` or a channel count other than 9 or 18 (all checked per
        sample, after the spatial transform).

    The remaining ``do_*``, ``p_*``, ``border_*`` and ``order_*`` arguments
    switch the individual transforms on, set their per-sample probability and
    are forwarded to ``interpolate_img`` respectively.
    """
    dim = len(patch_size)
    if dim == 2:
        out_spatial = (patch_size[0], patch_size[1])
    else:
        out_spatial = (patch_size[0], patch_size[1], patch_size[2])

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + out_spatial, dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + out_spatial, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = [patch_center_dist_from_border] * dim

    def _pick_angle(bounds):
        # Degenerate range -> fixed angle, without consuming a random number.
        low, high = bounds[0], bounds[1]
        return low if low == high else np.random.uniform(low, high)

    for b in range(data.shape[0]):
        grid = create_zero_centered_coordinate_mesh(patch_size)
        warped = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            el_alpha = np.random.uniform(alpha[0], alpha[1])
            el_sigma = np.random.uniform(sigma[0], sigma[1])
            grid = elastic_deform_coordinates(grid, el_alpha, el_sigma)
            warped = True

        # All three angles exist even when no rotation is drawn, because the
        # peak rotation further down reads the x angle unconditionally.
        rot_x = rot_y = rot_z = 0
        if np.random.uniform() < p_rot_per_sample and do_rotation:
            rot_x = _pick_angle(angle_x)
            if dim == 3:
                rot_y = _pick_angle(angle_y)
                rot_z = _pick_angle(angle_z)
                grid = rotate_coords_3d(grid, rot_x, rot_y, rot_z)
            else:
                grid = rotate_coords_2d(grid, rot_x)
            warped = True

        if np.random.uniform() < p_scale_per_sample and do_scale:
            shrink = np.random.random() < 0.5 and scale[0] < 1
            if shrink:
                factor = np.random.uniform(scale[0], 1)
            else:
                factor = np.random.uniform(max(scale[0], 1), scale[1])
            grid = scale_coords(grid, factor)
            warped = True

        if warped:
            # Move the mesh to the chosen patch centre, then interpolate.
            for axis in range(dim):
                if random_crop:
                    center = np.random.uniform(
                        patch_center_dist_from_border[axis],
                        data.shape[axis + 2] - patch_center_dist_from_border[axis])
                else:
                    center = int(np.round(data.shape[axis + 2] / 2.))
                grid[axis] += center

            for c in range(data.shape[1]):
                data_result[b, c] = interpolate_img(data[b, c], grid, order_data,
                                                    border_mode_data, cval=border_cval_data)
            if seg is not None:
                for c in range(seg.shape[1]):
                    seg_result[b, c] = interpolate_img(seg[b, c], grid, order_seg,
                                                       border_mode_seg, cval=border_cval_seg,
                                                       is_seg=True)
        else:
            # Nothing changed the mesh: cropping is sufficient.
            seg_slice = None if seg is None else seg[b:b + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[axis] - patch_size[axis] // 2
                          for axis in range(dim)]
                cropped, seg_slice = random_crop_aug(data[b:b + 1], seg_slice, patch_size, margin)
            else:
                cropped, seg_slice = center_crop_aug(data[b:b + 1], patch_size, seg_slice)
            data_result[b] = cropped[0]
            if seg is not None:
                seg_result[b] = seg_slice[0]

        # Rotate peaks / tensors by the angle that was used for the voxels.
        if dim > 2:
            raise ValueError(
                "augment_spatial_peaks only supports 2D at the moment")

        # In 2D the sampled angle is always the x angle, whatever the slice axis.
        in_plane_angle = rot_x
        if slice_dir == 0:
            peak_rot = (in_plane_angle, 0, 0)
        elif slice_dir == 1:
            peak_rot = (0, in_plane_angle, 0)
        elif slice_dir == 2:
            # The direction is inverted for z so that the peaks line up with
            # the rotated voxels (reason unclear; possibly differing
            # conventions for peaks and voxels).
            peak_rot = (0, 0, in_plane_angle * -1)
        else:
            raise ValueError("invalid slice_dir passed as argument")

        sample_aug = data_result[b]
        n_channels = sample_aug.shape[0]
        if n_channels == 9:
            data_result[b] = rotate_multiple_peaks(sample_aug, peak_rot[0], peak_rot[1], peak_rot[2])
        elif n_channels == 18:
            data_result[b] = rotate_multiple_tensors(sample_aug, peak_rot[0], peak_rot[1], peak_rot[2])
        else:
            raise ValueError("Incorrect number of channels (expected 9 or 18)")

    return data_result, seg_result
def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30,
                      do_elastic_deform=True, deformation_scale=(0, 0.25),
                      do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                      do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
                      border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True,
                      p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1,
                      independent_scale_for_each_axis=False, p_rot_per_axis: float = 1,
                      p_independent_scale_per_axis: int = 1):
    """
    Randomly deform, rotate, scale and crop a batch to ``patch_size``.

    :param data: array indexed as (sample, channel, *spatial)
    :param seg: label array indexed like ``data``, or None
    :param patch_size: output spatial size (2 or 3 entries)
    :param patch_center_dist_from_border: scalar or per-axis sequence bounding the random patch centre
    :param do_elastic_deform: enable elastic deformation
    :param deformation_scale: (low, high) range of the deformation scale, relative to the patch size. One value is
        drawn per sample; per axis it is multiplied by the patch extent to get sigma, and the magnitude is drawn
        uniformly from [sigma / 8, sigma / 2].
    :param do_rotation: enable rotation
    :param angle_x: (low, high) range of the rotation angle
    :param angle_y: as angle_x, only used for 3D patches
    :param angle_z: as angle_x, only used for 3D patches
    :param do_scale: enable scaling
    :param scale: (low, high) range of scale factors
    :param border_mode_data: forwarded to interpolate_img for data
    :param border_cval_data: forwarded to interpolate_img for data
    :param order_data: forwarded to interpolate_img for data
    :param border_mode_seg: forwarded to interpolate_img for seg
    :param border_cval_seg: forwarded to interpolate_img for seg
    :param order_seg: forwarded to interpolate_img for seg
    :param random_crop: random patch centre when true, volume centre otherwise
    :param p_el_per_sample: per-sample probability of elastic deformation
    :param p_scale_per_sample: per-sample probability of scaling
    :param p_rot_per_sample: per-sample probability of rotation
    :param independent_scale_for_each_axis: allow one scale factor per axis
    :param p_rot_per_axis: probability that an axis gets a non-zero angle once rotation was selected
    :param p_independent_scale_per_axis: probability that per-axis scaling is actually used when it is allowed
    :return: (data_result, seg_result) as float32 arrays; seg_result is None when seg is None
    """
    dim = len(patch_size)
    if dim == 2:
        out_spatial = (patch_size[0], patch_size[1])
    else:
        out_spatial = (patch_size[0], patch_size[1], patch_size[2])

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + out_spatial, dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + out_spatial, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = [patch_center_dist_from_border] * dim

    def _maybe_angle(bounds):
        # An axis is rotated only with probability p_rot_per_axis.
        if np.random.uniform() <= p_rot_per_axis:
            return np.random.uniform(bounds[0], bounds[1])
        return 0

    def _draw_scale():
        # Half of the draws shrink (if the range allows it), the rest enlarge.
        if np.random.random() < 0.5 and scale[0] < 1:
            return np.random.uniform(scale[0], 1)
        return np.random.uniform(max(scale[0], 1), scale[1])

    for b in range(data.shape[0]):
        grid = create_zero_centered_coordinate_mesh(patch_size)
        warped = False

        if np.random.uniform() < p_el_per_sample and do_elastic_deform:
            # One relative scale per sample, converted to pixels per axis.
            rel_scale = np.random.uniform(deformation_scale[0], deformation_scale[1])
            sigmas, magnitudes = [], []
            for axis in range(len(data[b].shape) - 1):
                axis_sigma = rel_scale * patch_size[axis]
                sigmas.append(axis_sigma)
                # The magnitude is tied to sigma so that the deformation is
                # visible but does not become degenerate.
                magnitudes.append(np.random.uniform(axis_sigma * (1 / 8), axis_sigma * (1 / 2)))
            grid = elastic_deform_coordinates_2(grid, sigmas, magnitudes)
            warped = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            rot_x = _maybe_angle(angle_x)
            if dim == 3:
                rot_y = _maybe_angle(angle_y)
                rot_z = _maybe_angle(angle_z)
                grid = rotate_coords_3d(grid, rot_x, rot_y, rot_z)
            else:
                grid = rotate_coords_2d(grid, rot_x)
            warped = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if independent_scale_for_each_axis and np.random.uniform() < p_independent_scale_per_axis:
                factor = [_draw_scale() for _ in range(dim)]
            else:
                factor = _draw_scale()
            grid = scale_coords(grid, factor)
            warped = True

        if warped:
            # Re-centre the mesh on zero before moving it to the patch centre.
            spatial_axes = tuple(range(1, len(grid.shape)))
            grid -= grid.mean(axis=spatial_axes, keepdims=True)

            for axis in range(dim):
                if random_crop:
                    center = np.random.uniform(
                        patch_center_dist_from_border[axis],
                        data.shape[axis + 2] - patch_center_dist_from_border[axis])
                else:
                    center = int(np.round(data.shape[axis + 2] / 2.))
                grid[axis] += center

            for c in range(data.shape[1]):
                data_result[b, c] = interpolate_img(data[b, c], grid, order_data,
                                                    border_mode_data, cval=border_cval_data)
            if seg is not None:
                for c in range(seg.shape[1]):
                    seg_result[b, c] = interpolate_img(seg[b, c], grid, order_seg,
                                                       border_mode_seg, cval=border_cval_seg,
                                                       is_seg=True)
        else:
            # Untouched mesh: a plain crop gives the same patch.
            seg_slice = None if seg is None else seg[b:b + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[axis] - patch_size[axis] // 2
                          for axis in range(dim)]
                cropped, seg_slice = random_crop_aug(data[b:b + 1], seg_slice, patch_size, margin)
            else:
                cropped, seg_slice = center_crop_aug(data[b:b + 1], patch_size, seg_slice)
            data_result[b] = cropped[0]
            if seg is not None:
                seg_result[b] = seg_slice[0]

    return data_result, seg_result
def augment_spatial_transform(data, seg=None, patch_size=None, patch_center_dist_from_border=None,
                              do_elastic_deform=True, alpha=(0., 900.), sigma=(9., 13.),
                              do_rotation=True,
                              angle_x=default_3D_augmentation_params.get("rotation_x"),
                              angle_y=default_3D_augmentation_params.get("rotation_y"),
                              angle_z=default_3D_augmentation_params.get("rotation_z"),
                              do_scale=True, scale=(0.85, 1.25), border_mode_data='constant', border_cval_data=0,
                              order_data=3, border_mode_seg="constant", border_cval_seg=-1, order_seg=1,
                              random_crop=False, p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1):
    """Spatially augment a single (un-batched) image and its segmentation.

    :param data: array indexed as (channel, *spatial). A 3-dimensional input
        gets a leading channel axis added, with a printed notice.
    :param seg: label array or ``None``; a 3-dimensional input gets a leading
        channel axis added as well.
    :param patch_size: output spatial size; defaults to axes 1..3 of ``data``.
    :param patch_center_dist_from_border: scalar or per-axis sequence bounding
        the random patch centre. NOTE(review): the default ``None`` is
        replicated per axis and then used in arithmetic when
        ``random_crop=True`` -- confirm that callers always pass a value in
        that case.
    :param alpha: (low, high) range of the first elastic parameter.
    :param sigma: (low, high) range of the second elastic parameter.
    :param angle_x: (low, high) angle range; equal bounds give a fixed angle.
    :param angle_y: as ``angle_x``; only used for 3D patches.
    :param angle_z: as ``angle_x``; only used for 3D patches.
    :param scale: (low, high) range of the isotropic scale factor.
    :param random_crop: random patch centre when true, volume centre otherwise.
    :return: ``(data_result, seg_result)``; ``seg_result`` is ``None`` when
        ``seg`` is ``None``.

    The remaining ``do_*``, ``p_*``, ``border_*`` and ``order_*`` arguments
    switch the individual transforms on, set their probability and are
    forwarded to ``interpolate_img`` respectively.
    """
    if len(data.shape) == 3:
        print("Attention: data shape must be (C, D, H, W). I will transform it.")
        data = data[None]
    if seg is not None and len(seg.shape) == 3:
        seg = seg[None]

    if patch_size is None:
        patch_size = (data.shape[1], data.shape[2], data.shape[3])
    dim = len(patch_size)
    if dim == 2:
        out_spatial = (patch_size[0], patch_size[1])
    else:
        out_spatial = (patch_size[0], patch_size[1], patch_size[2])

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0],) + out_spatial, dtype=np.float32)
    data_result = np.zeros((data.shape[0],) + out_spatial, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = [patch_center_dist_from_border] * dim

    def _pick_angle(bounds):
        # Degenerate range -> fixed angle, without consuming a random number.
        low, high = bounds[0], bounds[1]
        return low if low == high else np.random.uniform(low, high)

    grid = create_zero_centered_coordinate_mesh(patch_size)
    warped = False

    if np.random.uniform() < p_el_per_sample and do_elastic_deform:
        el_alpha = np.random.uniform(alpha[0], alpha[1])
        el_sigma = np.random.uniform(sigma[0], sigma[1])
        grid = elastic_deform_coordinates(grid, el_alpha, el_sigma)
        warped = True

    if np.random.uniform() < p_rot_per_sample and do_rotation:
        rot_x = _pick_angle(angle_x)
        if dim == 3:
            rot_y = _pick_angle(angle_y)
            rot_z = _pick_angle(angle_z)
            grid = rotate_coords_3d(grid, rot_x, rot_y, rot_z)
        else:
            grid = rotate_coords_2d(grid, rot_x)
        warped = True

    if np.random.uniform() < p_scale_per_sample and do_scale:
        shrink = np.random.random() < 0.5 and scale[0] < 1
        if shrink:
            factor = np.random.uniform(scale[0], 1)
        else:
            factor = np.random.uniform(max(scale[0], 1), scale[1])
        grid = scale_coords(grid, factor)
        warped = True

    if warped:
        # Shift the mesh to the chosen patch centre, then interpolate.
        for axis in range(dim):
            if random_crop:
                center = np.random.uniform(patch_center_dist_from_border[axis],
                                           data.shape[axis + 1] - patch_center_dist_from_border[axis])
            else:
                center = int(np.around(data.shape[axis + 1] / 2.))
            grid[axis] += center

        for c in range(data.shape[0]):
            data_result[c] = interpolate_img(data[c], grid, order_data,
                                             border_mode_data, cval=border_cval_data)
        if seg is not None:
            for c in range(seg.shape[0]):
                seg_result[c] = interpolate_img(seg[c], grid, order_seg,
                                                border_mode_seg, cval=border_cval_seg, is_seg=True)
    else:
        # Untouched mesh: crop instead. The crop helpers get a leading axis
        # added to their inputs, which is stripped again below.
        seg_batch = None if seg is None else seg[None]
        if random_crop:
            margin = [patch_center_dist_from_border[axis] - patch_size[axis] // 2
                      for axis in range(dim)]
            cropped, seg_batch = random_crop_aug(data[None], seg_batch, patch_size, margin)
        else:
            cropped, seg_batch = center_crop_aug(data[None], patch_size, seg_batch)
        data_result = cropped[0]
        if seg is not None:
            seg_result = seg_batch[0]

    return data_result, seg_result
def cc_augment(config_task, data, seg, patch_type, patch_size, patch_center_dist_from_border=30, do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.), do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25), border_mode_data='constant', border_cval_data=0, order_data=3, border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1, tag=''):
    """Augment a batch and extract patches, retrying until a patch is accepted.

    For every sample the function repeatedly draws a spatial transform and a
    patch centre, extracts the patch and checks it. A patch is accepted when
    the data patch is not all zero and, for ``patch_type`` 'fore' / 'small',
    the label patch contains the requested label.

    :param config_task: object providing ``task`` and ``num_class``; only read
        when ``patch_type == 'small'``.
    :param data: array indexed as (sample, channel, *spatial).
    :param seg: label array indexed like ``data``, or ``None`` (then
        ``patch_type`` is forced to 'any').
    :param patch_type: 'fore' (centre near a voxel with label > 0), 'small'
        (centre near a voxel with label 1 for 'Task05_Prostate', otherwise
        ``config_task.num_class - 1``) or 'any'.
    :param patch_size: output spatial size. It is concatenated to a list
        below, so it is presumably expected to be a list -- TODO confirm.
    :param patch_center_dist_from_border: scalar or per-axis sequence bounding
        the patch centre.
    :param tag: free-form string that only appears in the log message.
    :return: ``(data_result, seg_result, augs)``. ``augs`` is a list of the
        transform names that were applied. All three are ``None`` when the
        requested label does not occur in a sample.

    The remaining arguments (``alpha``, ``sigma``, ``angle_*``, ``scale``,
    ``do_*``, ``p_*``, ``border_*``, ``order_*``, ``random_crop``) control the
    random transforms and are forwarded to the coordinate / interpolation
    helpers.
    """
    # patch_center_dist_from_border should be no more than 1/2 patch size. otherwise code not available.
    # data: [n,c,d,h,w]
    # seg: [n,c,d,h,w]
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros([seg.shape[0], seg.shape[1]] + patch_size, dtype=np.float32)
    data_result = np.zeros([data.shape[0], data.shape[1]] + patch_size, dtype=np.float32)
    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    ## for-loop for dim[0]
    augs = list()
    for sample_id in range(data.shape[0]):
        # NOTE(review): the mesh is created once per sample, outside the retry
        # loop, so on a retry the transforms and the centre offset appear to
        # be applied on top of the previous attempt -- confirm this is intended.
        coords = create_zero_centered_coordinate_mesh(patch_size)
        # now find a nice center location and extract patch
        if seg is None:
            patch_type = 'any'
        handler = 0  # set to 1 once a patch has been accepted (or given up on)
        n = 0  # number of attempts made for this sample
        while handler == 0:
            # augmentation
            modified_coords = False
            if np.random.uniform() < p_el_per_sample and do_elastic_deform:
                a = np.random.uniform(alpha[0], alpha[1])
                s = np.random.uniform(sigma[0], sigma[1])
                coords = elastic_deform_coordinates(coords, a, s)
                modified_coords = True
                augs.append('elastic')
            if np.random.uniform() < p_rot_per_sample and do_rotation:
                if angle_x[0] == angle_x[1]:
                    a_x = angle_x[0]
                else:
                    a_x = np.random.uniform(angle_x[0], angle_x[1])
                if dim == 3:
                    if angle_y[0] == angle_y[1]:
                        a_y = angle_y[0]
                    else:
                        a_y = np.random.uniform(angle_y[0], angle_y[1])
                    if angle_z[0] == angle_z[1]:
                        a_z = angle_z[0]
                    else:
                        a_z = np.random.uniform(angle_z[0], angle_z[1])
                    coords = rotate_coords_3d(coords, a_x, a_y, a_z)
                else:
                    coords = rotate_coords_2d(coords, a_x)
                modified_coords = True
                augs.append('rotation')
            if np.random.uniform() < p_scale_per_sample and do_scale:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
                coords = scale_coords(coords, sc)
                modified_coords = True
                augs.append('scale')

            # find candidate area for center, the area is cand_point_coord +/- patch_size
            if patch_type in ['fore', 'small'] and seg is not None:
                if seg.shape[1] > 1:
                    logger.error('TBD for seg with multiple channels')
                if patch_type == 'fore':
                    lab_coords = np.where(
                        seg[sample_id, 0, ...] > 0)  # lab_coords: tuple
                elif patch_type == 'small':
                    if config_task.task == 'Task05_Prostate':
                        lab_coords = np.where(seg[sample_id, 0, ...] == 1)
                    else:
                        lab_coords = np.where(
                            seg[sample_id, 0, ...] == config_task.num_class - 1)
                if len(lab_coords[0]) > 0:  # 0 means no such label exists
                    idx = np.random.choice(len(lab_coords[0]))
                    cand_point_coord = [
                        coords[idx] for coords in lab_coords
                    ]  # coords for one random point from 'fore' ground
                else:
                    cand_point_coord = None

            if patch_type in ['fore', 'small'] and cand_point_coord is None:
                # The requested label is absent: give up and return None for
                # everything.
                ctr_list = None
                handler = 1
                data_result = None
                seg_result = None
                augs = None
            else:
                ctr_list = list()  # coords of the patch center
                for d in range(dim):
                    if random_crop:
                        if patch_type in ['fore', 'small'] and seg is not None:
                            # Keep the centre within roughly half a patch of
                            # the candidate voxel and inside the border limit.
                            low = max(
                                patch_center_dist_from_border[d] - 1,
                                cand_point_coord[d] - (patch_size[d] / 2 - 1))
                            low = int(low)
                            upper = min(
                                cand_point_coord[d] + (patch_size[d] / 2 - 1),
                                data.shape[d + 2] -
                                (patch_center_dist_from_border[d] - 1)
                            )  # +/- patch_size[d] is better but computation costly
                            upper = int(upper)
                            if low == upper:
                                ctr = int(low)
                            elif low < upper:
                                ctr = int(np.random.randint(low, upper))
                                # if n > 1:
                                #     logger.info('n:{}; [low,upper]:{}, ctr:{}'.format(n, str([low, upper]), ctr))
                            else:
                                # NOTE(review): after this log call `ctr` is not
                                # assigned for this axis, yet it is appended
                                # below -- confirm the intended handling.
                                logger.error(
                                    '(low:{} should be <= upper:{}). patch_type:{}, patch_center_dist_from_border:{}, cand_point_coord:{}, cand point seg value:{}, data.shape:{}, ctr_list:{}'
                                    .format(
                                        low, upper, str(patch_type),
                                        str(patch_center_dist_from_border),
                                        str(cand_point_coord),
                                        seg[sample_id, 0] + cand_point_coord,
                                        str(data.shape), str(ctr_list)))
                        elif patch_type == 'any':
                            if patch_center_dist_from_border[d] == data.shape[
                                    d + 2] - patch_center_dist_from_border[d]:
                                ctr = int(patch_center_dist_from_border[d])
                            elif patch_center_dist_from_border[d] < data.shape[
                                    d + 2] - patch_center_dist_from_border[d]:
                                ctr = int(
                                    np.random.randint(
                                        patch_center_dist_from_border[d],
                                        data.shape[d + 2] -
                                        patch_center_dist_from_border[d]))
                            else:
                                logger.error(
                                    'low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, data.shape:{}, ctr_list:{}'
                                    .format(str(patch_type),
                                            str(patch_center_dist_from_border),
                                            str(data.shape), str(ctr_list)))
                    else:  # center crop
                        ctr = int(np.round(data.shape[d + 2] / 2.))
                    ctr_list.append(ctr)

                # extracting patch
                if n < 10 and modified_coords:
                    # Interpolate at the transformed mesh (first 10 attempts only).
                    for d in range(dim):
                        coords[d] += ctr_list[d]
                    for channel_id in range(data.shape[1]):
                        data_result[sample_id, channel_id] = interpolate_img(
                            data[sample_id, channel_id], coords, order_data,
                            border_mode_data, cval=border_cval_data)
                    if seg is not None:
                        for channel_id in range(seg.shape[1]):
                            seg_result[sample_id, channel_id] = interpolate_img(
                                seg[sample_id, channel_id], coords, order_seg,
                                border_mode_seg, cval=border_cval_seg,
                                is_seg=True)
                else:
                    # Plain crop without a spatial transform; the list of
                    # applied augmentations is reset.
                    augs = list()
                    if seg is None:
                        s = None
                    else:
                        s = seg[sample_id:sample_id + 1]
                    if random_crop:
                        # margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                        # d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
                        d_tmps = list()
                        for channel_id in range(data.shape[1]):
                            d_tmp = utils.extract_roi_from_volume(
                                data[sample_id, channel_id], ctr_list,
                                patch_size, fill="zero")
                            d_tmps.append(d_tmp)
                        d = np.asarray(d_tmps)
                        if seg is not None:
                            s_tmps = list()
                            for channel_id in range(seg.shape[1]):
                                s_tmp = utils.extract_roi_from_volume(
                                    seg[sample_id, channel_id], ctr_list,
                                    patch_size, fill="zero")
                                s_tmps.append(s_tmp)
                            s = np.asarray(s_tmps)
                    else:
                        d, s = center_crop_aug(data[sample_id:sample_id + 1],
                                               patch_size, s)
                    # data_result[sample_id] = d[0]
                    data_result[sample_id] = d
                    if seg is not None:
                        # seg_result[sample_id] = s[0]
                        seg_result[sample_id] = s

                ## check patch
                # NOTE(review): the checks below look at the whole result
                # arrays, not only at the current sample's slice -- confirm.
                if patch_type in [
                        'fore'
                ]:  # cancer could be very very small. so use approximate method (i.e. use 'fore').
                    if np.any(seg_result > 0) and np.any(data_result != 0):
                        handler = 1
                    else:
                        handler = 0
                elif patch_type in ['small']:
                    if config_task.task == 'Task05_Prostate':
                        if np.any(seg_result == 1) and np.any(
                                data_result != 0):
                            handler = 1
                        else:
                            handler = 0
                    else:
                        if np.any(seg_result == config_task.num_class -
                                  1) and np.any(data_result != 0):
                            handler = 1
                        else:
                            handler = 0
                else:
                    if np.any(data_result != 0):
                        handler = 1
                    else:
                        handler = 0
                n += 1
                if n > 5:
                    # NOTE(review): `cand_point_coord` is only assigned for
                    # patch_type 'fore' / 'small' with a segmentation, and it
                    # is indexed with three entries here -- this looks like it
                    # fails for patch_type 'any' or 2D input; confirm.
                    logger.info(
                        'tag:{}, patch_type: {}; handler: {}; times: {}; cand point:{}; cand point seg value:{}; ctr_list:{}; data.shape:{}; np.unique(seg_result):{}; np.sum(data_result):{}'
                        .format(
                            tag, patch_type, handler, n,
                            str(cand_point_coord),
                            seg[sample_id, 0, cand_point_coord[0],
                                cand_point_coord[1], cand_point_coord[2]],
                            str(ctr_list), str(data.shape),
                            np.unique(seg_result, return_counts=True),
                            np.sum(data_result)))
    return data_result, seg_result, augs
def affine_transformation(vol, radius, translate, scale, bspline_order, border_mode, constant_val, is_reverse):
    """
    Resample a batch of volumes on an affinely transformed coordinate grid.

    forward:  scale -> rotation -> translation
    backward: translation -> rotation -> scale

    :param vol: np.ndarray or torch.Tensor indexed as (batch, channel, *spatial); a tensor is moved to the CPU
        and converted to numpy for the computation
    :param radius: three rotation angles, consumed by create_matrix_rotation_{x,y,z}_3d
    :param translate: per-axis offsets that are subtracted from the coordinates
    :param scale: single factor, or a tuple / list / array with one factor per axis
    :param bspline_order: forwarded to interpolate_img as ``order``
    :param border_mode: forwarded to interpolate_img as ``mode``
    :param constant_val: forwarded to interpolate_img as ``cval``
    :param is_reverse: bool, selects the backward order of operations
    :return: float32 result with the shape of ``vol``; a torch.Tensor on the input's device when ``vol`` was a
        tensor, otherwise an np.ndarray
    """
    came_as_tensor = isinstance(vol, torch.Tensor)
    _device = torch.device("cpu")
    if came_as_tensor:
        _device = vol.device
        vol = vol.cpu().detach().numpy()

    shape = vol.shape[2:]
    dim = len(vol.shape) - 2
    center = tuple((extent - 1) / 2. for extent in shape)

    # Index grid of the spatial axes, shifted so that the centre is at zero.
    axis_ranges = tuple(np.arange(extent) for extent in shape)
    coords = np.array(np.meshgrid(*axis_ranges, indexing='ij')).astype(float)
    for axis, mid in enumerate(center):
        coords[axis] -= mid

    def _shift(grid):
        # Subtract the translation along every spatial axis (in place).
        for axis in range(dim):
            grid[axis] -= translate[axis]
        return grid

    def _rescale(grid):
        # Per-axis factors when a sequence is given, one global factor otherwise.
        if isinstance(scale, (tuple, list, np.ndarray)):
            assert len(scale) == len(grid)
            for axis in range(len(scale)):
                grid[axis] *= scale[axis]
        else:
            grid *= scale
        return grid

    def _rotate(grid, steps):
        # Chain the single-axis rotation matrices in the given order and apply
        # the product to the flattened coordinate list.
        matrix = np.identity(len(grid))
        for build_matrix, angle_index in steps:
            matrix = build_matrix(radius[angle_index], matrix)
        flat = grid.reshape(len(grid), -1).transpose()
        return np.dot(flat, matrix).transpose().reshape(grid.shape)

    if is_reverse:
        coords = _shift(coords)
        coords = _rotate(coords, ((create_matrix_rotation_z_3d, 2),
                                  (create_matrix_rotation_y_3d, 1),
                                  (create_matrix_rotation_x_3d, 0)))
        coords = _rescale(coords)
    else:
        coords = _rescale(coords)
        coords = _rotate(coords, ((create_matrix_rotation_x_3d, 0),
                                  (create_matrix_rotation_y_3d, 1),
                                  (create_matrix_rotation_z_3d, 2)))
        coords = _shift(coords)

    # Back from centred coordinates to array indices.
    for axis in range(dim):
        coords[axis] += center[axis]

    warped = np.zeros_like(vol, dtype=np.float32)
    for b in range(vol.shape[0]):
        for c in range(vol.shape[1]):
            warped[b, c] = interpolate_img(vol[b, c], coords, order=bspline_order,
                                           mode=border_mode, cval=constant_val)

    if came_as_tensor:
        return torch.from_numpy(warped).to(_device)
    return warped
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30,
                    do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                    do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                    do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True):
    """Deform, rotate and scale every sample of a batch and resample a patch.

    Unlike the probabilistic variants, every enabled transform is applied to
    every sample and the patch is always produced by interpolation.

    :param data: array indexed as (sample, channel, *spatial).
    :param seg: label array indexed like ``data``, or ``None``.
    :param patch_size: output spatial size (2 or 3 entries).
    :param patch_center_dist_from_border: scalar or per-axis sequence; with
        ``random_crop`` the patch centre is drawn uniformly from
        ``[dist, shape - dist]`` along each axis.
    :param do_elastic_deform: enable elastic deformation.
    :param alpha: (low, high) range of the first elastic parameter.
    :param sigma: (low, high) range of the second elastic parameter.
    :param do_rotation: enable rotation.
    :param angle_x: (low, high) angle range; equal bounds give a fixed angle.
    :param angle_y: as ``angle_x``; only used for 3D patches.
    :param angle_z: as ``angle_x``; only used for 3D patches.
    :param do_scale: enable scaling.
    :param scale: (low, high) range of the isotropic scale factor.
    :param border_mode_data: forwarded to ``interpolate_img`` for ``data``.
    :param border_cval_data: forwarded to ``interpolate_img`` for ``data``.
    :param order_data: forwarded to ``interpolate_img`` for ``data``.
    :param border_mode_seg: forwarded to ``interpolate_img`` for ``seg``.
    :param border_cval_seg: forwarded to ``interpolate_img`` for ``seg``.
    :param order_seg: forwarded to ``interpolate_img`` for ``seg``.
    :param random_crop: random patch centre when true, volume centre otherwise.
    :return: ``(data_result, seg_result)`` as float32 arrays of shape
        (sample, channel, *patch_size); ``seg_result`` is ``None`` when ``seg``
        is ``None``.
    """
    dim = len(patch_size)
    if dim == 2:
        out_spatial = (patch_size[0], patch_size[1])
    else:
        out_spatial = (patch_size[0], patch_size[1], patch_size[2])

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + out_spatial, dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + out_spatial, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = [patch_center_dist_from_border] * dim

    def _pick_angle(bounds):
        # Degenerate range -> fixed angle, without consuming a random number.
        low, high = bounds[0], bounds[1]
        return low if low == high else np.random.uniform(low, high)

    for b in range(data.shape[0]):
        grid = create_zero_centered_coordinate_mesh(patch_size)

        if do_elastic_deform:
            el_alpha = np.random.uniform(alpha[0], alpha[1])
            el_sigma = np.random.uniform(sigma[0], sigma[1])
            grid = elastic_deform_coordinates(grid, el_alpha, el_sigma)

        if do_rotation:
            rot_x = _pick_angle(angle_x)
            if dim == 3:
                rot_y = _pick_angle(angle_y)
                rot_z = _pick_angle(angle_z)
                grid = rotate_coords_3d(grid, rot_x, rot_y, rot_z)
            else:
                grid = rotate_coords_2d(grid, rot_x)

        if do_scale:
            # Half of the draws shrink (if the range allows it), the rest enlarge.
            shrink = np.random.random() < 0.5 and scale[0] < 1
            if shrink:
                factor = np.random.uniform(scale[0], 1)
            else:
                factor = np.random.uniform(max(scale[0], 1), scale[1])
            grid = scale_coords(grid, factor)

        # Shift the zero-centred mesh to the patch centre inside the sample.
        for axis in range(dim):
            if random_crop:
                center = np.random.uniform(patch_center_dist_from_border[axis],
                                           data.shape[axis + 2] - patch_center_dist_from_border[axis])
            else:
                center = int(np.round(data.shape[axis + 2] / 2.))
            grid[axis] += center

        for c in range(data.shape[1]):
            data_result[b, c] = interpolate_img(data[b, c], grid, order_data,
                                                border_mode_data, cval=border_cval_data)
        if seg is not None:
            for c in range(seg.shape[1]):
                seg_result[b, c] = interpolate_img(seg[b, c], grid, order_seg,
                                                   border_mode_seg, cval=border_cval_seg, is_seg=True)

    return data_result, seg_result
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, do_elastic_deform=True,
                    alpha=(0., 1000.), sigma=(10., 13.), do_rotation=True, angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25),
                    border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True,
                    p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1,
                    independent_scale_for_each_axis=False, p_rot_per_axis: float = 1):
    """Sample one spatially augmented patch of ``patch_size`` per batch entry.

    Each of elastic deformation, rotation and scaling is applied to a sample
    only if it is enabled and a per-sample random draw falls below its
    probability. If at least one of them fired, the patch is produced by
    interpolating along the transformed coordinate mesh; otherwise the patch
    is obtained with a plain (random or centre) padded crop.

    Args:
        data: array indexed as ``data[sample, channel]``; axes 2.. are spatial.
        seg: optional array indexed like ``data``; ``None`` skips it.
        patch_size: output patch extent per spatial axis (length 2 or 3).
        patch_center_dist_from_border: scalar or per-axis sequence bounding
            where random patch centres may fall.
        do_elastic_deform, alpha, sigma, p_el_per_sample: elastic deformation
            switch, ``(low, high)`` parameter ranges and per-sample probability.
        do_rotation, angle_x, angle_y, angle_z, p_rot_per_sample,
            p_rot_per_axis: rotation switch, per-axis ``(low, high)`` ranges
            (presumably radians, given the 2*pi defaults), per-sample
            probability and the probability that an individual axis is rotated
            at all (otherwise its angle is 0). Only ``angle_x`` is used in 2D.
        do_scale, scale, p_scale_per_sample, independent_scale_for_each_axis:
            scaling switch, ``(low, high)`` range, per-sample probability and
            whether one factor per axis is drawn instead of a single one.
        border_mode_data, border_cval_data, order_data: interpolation settings
            for ``data``; the border mode/value are also translated into
            padding arguments for the plain-crop path.
        border_mode_seg, border_cval_seg, order_seg: same for ``seg``.
        random_crop: random patch centre if true, otherwise the image centre.

    Returns:
        ``(data_result, seg_result)``; float32 arrays of shape
        ``(n_samples, n_channels, *patch_size)``. ``seg_result`` is ``None``
        when ``seg`` is ``None``.
    """
    dim = len(patch_size)
    patch_shape = tuple(patch_size)

    seg_result = None
    if seg is not None:
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + patch_shape, dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + patch_shape, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False

        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            # Every axis gets its own chance of being rotated; skipped axes use angle 0.
            if np.random.uniform() <= p_rot_per_axis:
                a_x = np.random.uniform(angle_x[0], angle_x[1])
            else:
                a_x = 0

            if dim == 3:
                if np.random.uniform() <= p_rot_per_axis:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])
                else:
                    a_y = 0

                if np.random.uniform() <= p_rot_per_axis:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])
                else:
                    a_z = 0

                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True

        if do_scale and np.random.uniform() < p_scale_per_sample:
            if not independent_scale_for_each_axis:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
            else:
                sc = []
                for _ in range(dim):
                    if np.random.random() < 0.5 and scale[0] < 1:
                        sc.append(np.random.uniform(scale[0], 1))
                    else:
                        sc.append(np.random.uniform(max(scale[0], 1), scale[1]))

            coords = scale_coords(coords, sc)
            modified_coords = True

        if modified_coords:
            # Shift the zero-centred mesh onto the chosen patch centre.
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr

            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(
                    data[sample_id, channel_id], coords, order_data,
                    border_mode_data, cval=border_cval_data)

            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(
                        seg[sample_id, channel_id], coords, order_seg,
                        border_mode_seg, cval=border_cval_seg, is_seg=True)
        else:
            # No transform fired: a plain crop is enough, no interpolation needed.
            # The translation below is total, so border modes other than
            # 'nearest'/'constant' no longer leave the pad arguments unbound.
            pad_mode_data, pad_kwargs_data = _pad_args_for_border_mode(
                border_mode_data, border_cval_data)
            pad_mode_seg, pad_kwargs_seg = _pad_args_for_border_mode(
                border_mode_seg, border_cval_seg)

            sample_data = data[sample_id:sample_id + 1]
            sample_seg = None if seg is None else seg[sample_id:sample_id + 1]

            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2
                          for d in range(dim)]
                cropped_data, cropped_seg = random_crop_aug(
                    sample_data, sample_seg, patch_size, margin,
                    pad_mode=pad_mode_data, pad_kwargs=pad_kwargs_data,
                    pad_mode_seg=pad_mode_seg, pad_kwargs_seg=pad_kwargs_seg)
            else:
                cropped_data, cropped_seg = center_crop_aug(
                    sample_data, patch_size, sample_seg,
                    pad_mode=pad_mode_data, pad_kwargs=pad_kwargs_data,
                    pad_mode_seg=pad_mode_seg, pad_kwargs_seg=pad_kwargs_seg)

            data_result[sample_id] = cropped_data[0]
            if seg is not None:
                seg_result[sample_id] = cropped_seg[0]

    return data_result, seg_result


def _pad_args_for_border_mode(border_mode, border_cval):
    """Translate an interpolation border mode into ``(pad_mode, pad_kwargs)``.

    ``'nearest'`` becomes ``'edge'`` and ``'constant'`` carries its fill value
    as ``constant_values``. Every other name is forwarded unchanged with no
    extra keyword arguments, on the assumption that it is already a padding
    mode the crop helpers understand (edge/median/minimum/reflect/symmetric/...).
    NOTE(review): confirm which mode names the crop helpers accept.
    """
    if border_mode == 'nearest':
        return 'edge', {}
    if border_mode == 'constant':
        return border_mode, {'constant_values': border_cval}
    return border_mode, {}
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, do_elastic_deform=True,
                    alpha=(0., 1000.), sigma=(10., 13.), do_rotation=True, angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25),
                    border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, crop_mode='random',
                    p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1):
    """Sample one spatially augmented patch per batch entry with a selectable crop mode.

    Elastic deformation, rotation and scaling are each applied to a sample
    when a per-sample random draw is below the matching probability and the
    transform is enabled. If any of them fired the patch is interpolated along
    the transformed mesh, otherwise it is taken with a plain crop.

    Args:
        data: array indexed as ``data[sample, channel]``; axes 2.. are spatial.
        seg: optional array indexed like ``data``. ``crop_mode='roi'`` reads
            ``seg[sample, 0]`` on the interpolation path.
        patch_size: output patch extent per spatial axis (length 2 or 3).
        patch_center_dist_from_border: scalar or per-axis sequence bounding
            where random patch centres may fall.
        do_elastic_deform, alpha, sigma, p_el_per_sample: elastic deformation
            switch, ``(low, high)`` parameter ranges and per-sample probability.
        do_rotation, angle_x, angle_y, angle_z, p_rot_per_sample: rotation
            switch, per-axis ``(low, high)`` ranges (equal bounds are used
            verbatim without a random draw) and per-sample probability.
        do_scale, scale, p_scale_per_sample: scaling switch, ``(low, high)``
            range and per-sample probability.
        border_mode_data, border_cval_data, order_data: forwarded to
            ``interpolate_img`` for ``data``.
        border_mode_seg, border_cval_seg, order_seg: forwarded to
            ``interpolate_img`` for ``seg`` (called with ``is_seg=True``).
        crop_mode: ``'roi'``, ``'center'`` or anything else for a random crop.

    Returns:
        ``(data_result, seg_result)``; float32 arrays of shape
        ``(n_samples, n_channels, *patch_size)``. ``seg_result`` is ``None``
        when ``seg`` is ``None``.
    """
    dim = len(patch_size)

    # Anything that is not 2D is allocated as 3D.
    n_spatial_axes = 2 if dim == 2 else 3

    def _blank_like(reference):
        # float32 output buffer that keeps the batch/channel extents of `reference`.
        leading = (reference.shape[0], reference.shape[1])
        trailing = tuple(patch_size[axis] for axis in range(n_spatial_axes))
        return np.zeros(leading + trailing, dtype=np.float32)

    def _draw_angle(bounds):
        # Degenerate ranges are used verbatim and do not consume a random number.
        low, high = bounds[0], bounds[1]
        return low if low == high else np.random.uniform(low, high)

    seg_result = _blank_like(seg) if seg is not None else None
    data_result = _blank_like(data)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    use_roi = crop_mode == 'roi'
    use_center = crop_mode == 'center'

    for sample_idx in range(data.shape[0]):
        mesh = create_zero_centered_coordinate_mesh(patch_size)
        transformed = False

        # Each gate draws its random number first, even when the transform is
        # switched off, so the random stream does not depend on the do_* flags.
        gate = np.random.uniform()
        if gate < p_el_per_sample and do_elastic_deform:
            magnitude = np.random.uniform(alpha[0], alpha[1])
            smoothness = np.random.uniform(sigma[0], sigma[1])
            mesh = elastic_deform_coordinates(mesh, magnitude, smoothness)
            transformed = True

        gate = np.random.uniform()
        if gate < p_rot_per_sample and do_rotation:
            rot_x = _draw_angle(angle_x)
            if dim == 3:
                rot_y = _draw_angle(angle_y)
                rot_z = _draw_angle(angle_z)
                mesh = rotate_coords_3d(mesh, rot_x, rot_y, rot_z)
            else:
                mesh = rotate_coords_2d(mesh, rot_x)
            transformed = True

        gate = np.random.uniform()
        if gate < p_scale_per_sample and do_scale:
            # Coin flip between the shrinking and the enlarging part of the range.
            shrink = np.random.random() < 0.5 and scale[0] < 1
            low, high = (scale[0], 1) if shrink else (max(scale[0], 1), scale[1])
            mesh = scale_coords(mesh, np.random.uniform(low, high))
            transformed = True

        if not transformed:
            # Nothing fired: take a plain crop instead of interpolating.
            sample_data = data[sample_idx:sample_idx + 1]
            sample_seg = None if seg is None else seg[sample_idx:sample_idx + 1]

            if use_roi:
                patch, patch_seg = crop(sample_data, sample_seg, patch_size, crop_type='roi')
            elif use_center:
                patch, patch_seg = center_crop_aug(sample_data, patch_size, sample_seg)
            else:
                margin = [patch_center_dist_from_border[axis] - patch_size[axis] // 2
                          for axis in range(dim)]
                patch, patch_seg = random_crop_aug(sample_data, sample_seg, patch_size, margin)

            data_result[sample_idx] = patch[0]
            if seg is not None:
                seg_result[sample_idx] = patch_seg[0]
            continue

        # Shift the zero-centred mesh onto the patch centre chosen by crop_mode.
        if use_roi:
            lower_bounds = get_lbls_for_roi_crop(patch_size, data.shape, seg[sample_idx, 0])
        for axis in range(dim):
            if use_roi:
                centre = int(lower_bounds[axis] + np.round(patch_size[axis] / 2.))
            elif use_center:
                centre = int(np.round(data.shape[axis + 2] / 2.))
            else:
                centre = np.random.uniform(
                    patch_center_dist_from_border[axis],
                    data.shape[axis + 2] - patch_center_dist_from_border[axis])
            mesh[axis] += centre

        for ch in range(data.shape[1]):
            data_result[sample_idx, ch] = interpolate_img(data[sample_idx, ch], mesh, order_data,
                                                          border_mode_data, cval=border_cval_data)

        if seg is not None:
            for ch in range(seg.shape[1]):
                seg_result[sample_idx, ch] = interpolate_img(seg[sample_idx, ch], mesh, order_seg,
                                                             border_mode_seg, cval=border_cval_seg,
                                                             is_seg=True)

    return data_result, seg_result
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, do_elastic_deform=True,
                    alpha=(0., 1000.), sigma=(10., 13.), do_rotation=True, angle_x=(0, 2 * np.pi),
                    angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi), do_scale=True, scale=(0.75, 1.25),
                    border_mode_data='nearest', border_cval_data=0, order_data=1,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, fraction=0.0,
                    spacing=(1, 1, 1)):
    """Sample a spatially augmented patch, optionally centred on the foreground.

    Inputs with fewer than ``len(patch_size) + 2`` axes are padded with leading
    singleton axes. A zero-centred mesh is optionally elastically deformed
    (with the deformation magnitude divided by ``spacing``), rotated and
    scaled, then moved either to a random centre or, when the random draw
    ``rval`` is below ``fraction``, to the jittered centroid of the foreground
    found in channels ``1:`` of ``seg``.

    Args:
        data: image array; after promotion it is indexed as
            ``data[sample, channel]`` with the remaining axes spatial.
        seg: optional label array laid out like ``data``; channels ``1:`` are
            treated as foreground for object-centred sampling
            (presumably one-hot with background in channel 0; verify against
            the caller). ``None`` disables segmentation output and
            object-centred sampling.
        patch_size: output patch extent per spatial axis (list, tuple or array).
        patch_center_dist_from_border: scalar or per-axis sequence bounding
            where random patch centres may fall.
        do_elastic_deform, alpha, sigma: elastic deformation switch and the
            ``(low, high)`` ranges its two parameters are drawn from.
        do_rotation, angle_x, angle_y, angle_z: rotation switch and per-axis
            ``(low, high)`` ranges; equal bounds are used verbatim without a
            random draw. Only ``angle_x`` is used for 2D.
        do_scale, scale: scaling switch and ``(low, high)`` range.
        border_mode_data, border_cval_data, order_data: forwarded to
            ``interpolate_img`` for ``data``.
        border_mode_seg, border_cval_seg, order_seg: forwarded to
            ``interpolate_img`` for ``seg``.
        fraction: threshold in ``[0, 1]``; a sample is object-centred when its
            uniform draw ``rval`` is strictly below it.
        spacing: per-axis divisor applied to the elastic deformation
            magnitude; only the first ``len(patch_size)`` entries are used.

    Returns:
        ``(data_patch, seg_patch, rval)``: the first sample of the augmented
        data, the first sample of the augmented segmentation (``None`` when
        ``seg`` is ``None``) and the last drawn ``rval``.
    """
    dim = len(patch_size)
    patch_extent = np.asarray(patch_size)
    patch_shape = tuple(patch_size)

    # Promote to (sample, channel, *spatial). The target rank follows the patch
    # dimensionality instead of being fixed at five axes.
    full_rank = dim + 2
    while len(data.shape) < full_rank:
        data = data[np.newaxis]

    seg_result = None
    if seg is not None:
        while len(seg.shape) < full_rank:
            seg = seg[np.newaxis]
        seg_result = np.zeros((seg.shape[0], seg.shape[1]) + patch_shape, dtype=np.float32)
    data_result = np.zeros((data.shape[0], data.shape[1]) + patch_shape, dtype=np.float32)

    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    spatial_extent = np.asarray(data.shape[2:])
    rval = None  # stays None only if there are no samples to process

    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)

        if do_elastic_deform:
            # One magnitude per axis, rescaled by the spacing of that axis.
            a_base = np.repeat(np.random.uniform(alpha[0], alpha[1]), dim)
            a = a_base / np.asarray(spacing)[:dim]
            s = np.repeat(np.random.uniform(sigma[0], sigma[1]), dim)
            coords = elastic_deform_coordinates(coords, a, s)

        if do_rotation:
            if angle_x[0] == angle_x[1]:
                a_x = angle_x[0]
            else:
                a_x = np.random.uniform(angle_x[0], angle_x[1])

            if dim == 3:
                if angle_y[0] == angle_y[1]:
                    a_y = angle_y[0]
                else:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])

                if angle_z[0] == angle_z[1]:
                    a_z = angle_z[0]
                else:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])

                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)

        if do_scale:
            if np.random.random() < 0.5 and scale[0] < 1:
                sc = np.random.uniform(scale[0], 1)
            else:
                sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)

        # Both numbers are always drawn, jitter first, so the random stream
        # does not depend on which centring strategy ends up being used.
        jitter = np.random.uniform(0.8, 1.2)
        rval = np.random.uniform(0.0, 1.0)

        # Object-centred sampling needs a segmentation with at least one
        # foreground voxel; otherwise fall back to a random centre rather than
        # propagating a NaN centroid.
        ctr_obj = None
        if rval < fraction and seg is not None:
            # Drop the channel index, keep one index array per spatial axis.
            foreground = np.where(seg[sample_id, 1:] > 0)[1:]
            if foreground[0].size > 0:
                centroid = np.array([np.mean(axis_idx) for axis_idx in foreground])
                ctr_obj = np.clip(jitter * centroid, patch_extent,
                                  spatial_extent - patch_extent)

        for d in range(dim):
            if ctr_obj is None:
                ctr = np.random.uniform(patch_center_dist_from_border[d],
                                        data.shape[d + 2] - patch_center_dist_from_border[d])
            else:
                ctr = ctr_obj[d]
            coords[d] += ctr

        for channel_id in range(data.shape[1]):
            data_result[sample_id, channel_id] = interpolate_img(
                data[sample_id, channel_id], coords, order_data,
                border_mode_data, cval=border_cval_data)

        if seg is not None:
            for channel_id in range(seg.shape[1]):
                seg_result[sample_id, channel_id] = interpolate_img(
                    seg[sample_id, channel_id], coords, order_seg,
                    border_mode_seg, cval=border_cval_seg)

    seg_patch = None if seg_result is None else seg_result[0]
    return data_result[0], seg_patch, rval