def generate_train_batch(self):
    """Assemble one training batch from the patients returned by ``self.get_batch_pids()``.

    Each patient's ``patient['data']`` file is read with ``mmap_mode='r'``. Entry 0 along its
    first axis becomes the image (cast to float16, with a channel axis prepended). Entry 1
    becomes the segmentation (cast to uint8). If any spatial dim is smaller than
    ``self.cf.pre_crop_size``, both are padded via ``dutils.pad_nd_image``. No cropping is done.

    For every value in ``self.unique_ts`` two statistics are accumulated:
      * ``roi_counts``: number of ROIs in the sample whose ``patient[self.balance_target]``
        entry equals that value. ROI id ``k`` in the seg is looked up at index ``k - 1``.
      * ``empty_counts``: number of samples without any such ROI.

    Returns:
        dict with keys 'data', 'seg', 'pid', 'roi_counts', 'empty_counts', followed by one
        stacked array per name in ``self.cf.roi_items``.
    """
    # Everything here is per batch; print statements get confusing due to multithreading.
    pid_list = self.get_batch_pids()
    n_targets = len(self.unique_ts)

    imgs, segs = [], []
    roi_item_lists = {item: [] for item in self.cf.roi_items}
    roi_counts = np.zeros((n_targets,), dtype='uint32')
    empty_counts = np.zeros((n_targets,), dtype='uint32')

    for pid in pid_list:
        patient = self._data[pid]
        stacked = np.load(patient['data'], mmap_mode='r')
        img = stacked[0].astype('float16')[np.newaxis]
        mask = stacked[1].astype('uint8')

        shp = img[0].shape
        assert shp == mask.shape, "spatial shape incongruence betw. data and seg"
        min_shp = self.cf.pre_crop_size
        if np.any([shp[d] < min_shp[d] for d in range(len(shp))]):
            padded_shp = [np.max([shp[d], min_shp[d]]) for d in range(len(shp))]
            img = dutils.pad_nd_image(img, (len(img), *padded_shp))
            mask = dutils.pad_nd_image(mask, padded_shp)

        imgs.append(img)
        segs.append(mask[np.newaxis])
        # after the loop, each list holds one entry per batch patient
        for item, values in roi_item_lists.items():
            values.append(patient[item])

        # ROI ids present in the sample, shifted to 0-based indices into the patient's targets.
        roi_ixs = np.unique(mask[mask > 0]) - 1
        mask_is_empty = not np.any(mask)
        for tix, target in enumerate(self.unique_ts):
            hits = np.count_nonzero(patient[self.balance_target][roi_ixs] == target)
            roi_counts[tix] += hits
            empty_counts[tix] += int(hits == 0)
            # TODO: remove this sanity check once verified
            if mask_is_empty:
                assert hits == 0

    batch = {'data': np.array(imgs),
             'seg': np.array(segs).astype('uint8'),
             'pid': pid_list,
             'roi_counts': roi_counts,
             'empty_counts': empty_counts}
    batch.update({item: np.array(values) for item, values in roi_item_lists.items()})
    return batch
def generate_train_batch(self):
    """Build the inference batch for the next patient in ``self.dataset_pids``.

    Steps:
      * Load the patient's volume and segmentation.
      * Pad both up to ``self.patch_size`` where they are smaller.
      * Derive whole-patient targets: a 3D batch when ``cf.dim == 3`` or
        ``cf.merge_2D_to_3D_preds``, and/or a 2D batch with slices on the batch axis when
        ``cf.dim == 2``.
      * If the volume exceeds ``self.patch_size``, tile it into patches stacked along the
        batch axis.

    ``self.patient_ix`` advances and wraps to 0 after the last patient. Channels listed in
    ``self.cf.drop_channels_test`` are zeroed in the returned ``'data'``.

    Returns:
        dict produced by ``ConvertSegToBoundingBoxCoordinates``, extended with
        'patient_bb_target', 'patient_roi_labels', 'original_img_shape' and, when tiled,
        'patch_crop_coords'.
    """
    pid = self.dataset_pids[self.patient_ix]
    patient = self._data[pid]
    # NOTE(review): axes=(3, 1, 2, 0) yields (c, y, x, z) if the arrays are stored as
    # (z, y, x, c); an earlier comment claimed a (c, z, y, x) source layout -- confirm.
    data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(3, 1, 2, 0)).copy()
    # [0] keeps the first entry of the leading axis, leaving a purely spatial (3D) seg.
    seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(3, 1, 2, 0))[0].copy()
    batch_class_targets = np.array([patient['class_target']])

    # pad data if smaller than patch_size seen during training.
    if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.patch_size)]):
        new_shape = [data.shape[0]] + [np.max([data.shape[dim + 1], ps])
                                       for dim, ps in enumerate(self.patch_size)]
        data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
        # seg has no channel axis, so it is padded to the spatial part of new_shape only.
        # Passing the channel-prefixed shape would request one more dim than seg has.
        seg = dutils.pad_nd_image(seg, new_shape[1:])

    # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
    if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
        out_data = data[np.newaxis]
        out_seg = seg[np.newaxis, np.newaxis]
        out_targets = batch_class_targets
        batch_3D = {'data': out_data, 'seg': out_seg, 'class_target': out_targets, 'pid': pid}
        converter = ConvertSegToBoundingBoxCoordinates(
            dim=3, get_rois_from_seg_flag=False,
            class_specific_seg_flag=self.cf.class_specific_seg_flag)
        batch_3D = converter(**batch_3D)
        batch_3D.update({'patient_bb_target': batch_3D['bb_target'],
                         'patient_roi_labels': batch_3D['class_target'],
                         'original_img_shape': out_data.shape})

    if self.cf.dim == 2:
        out_data = np.transpose(data, axes=(3, 0, 1, 2))  # (z, c, y, x)
        out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
        out_targets = np.array(np.repeat(batch_class_targets, out_data.shape[0], axis=0))

        # if set to not None, add neighbouring slices to each selected slice in channel dimension.
        if self.cf.n_3D_context is not None:
            n_ctx = self.cf.n_3D_context
            slice_range = range(n_ctx, out_data.shape[0] + n_ctx)
            out_data = np.pad(out_data, ((n_ctx, n_ctx), (0, 0), (0, 0), (0, 0)),
                              'constant', constant_values=0)
            out_data = np.array([
                np.concatenate([out_data[ii] for ii in range(slice_id - n_ctx, slice_id + n_ctx + 1)],
                               axis=0)
                for slice_id in slice_range])

        batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': out_targets, 'pid': pid}
        converter = ConvertSegToBoundingBoxCoordinates(
            dim=2, get_rois_from_seg_flag=False,
            class_specific_seg_flag=self.cf.class_specific_seg_flag)
        batch_2D = converter(**batch_2D)

        if self.cf.merge_2D_to_3D_preds:
            batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                             'patient_roi_labels': batch_3D['patient_roi_labels'],
                             'original_img_shape': out_data.shape})
        else:
            batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                             'patient_roi_labels': batch_2D['class_target'],
                             'original_img_shape': out_data.shape})

    out_batch = batch_3D if self.cf.dim == 3 else batch_2D
    patient_batch = out_batch

    # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
    # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
    if np.any([data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
        patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
        new_img_batch, new_seg_batch = [], []

        for cix, c in enumerate(patch_crop_coords_list):
            new_seg_batch.append(seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]])

            # if set to not None, add neighbouring slices to each selected slice in channel dimension.
            # correct patch_crop coordinates by added slices of 3D context: after padding z by
            # n_3D_context on both sides, [c4, c5 + 2n) covers original slices [c4 - n, c5 + n).
            if self.cf.dim == 2 and self.cf.n_3D_context is not None:
                tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
                if cix == 0:
                    data = np.pad(data, ((0, 0), (0, 0), (0, 0),
                                         (self.cf.n_3D_context, self.cf.n_3D_context)),
                                  'constant', constant_values=0)
            else:
                tmp_c_5 = c[5]

            new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])

        data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
        seg = np.array(new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
        batch_class_targets = np.repeat(batch_class_targets, len(patch_crop_coords_list), axis=0)

        if self.cf.dim == 2:
            if self.cf.n_3D_context is not None:
                # context slices (z) become the channel axis; only channel 0 of data is used.
                data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
            else:
                # all patches have z dimension 1 (slices). discard dimension
                data = data[..., 0]
            seg = seg[..., 0]

        patch_batch = {'data': data, 'seg': seg, 'class_target': batch_class_targets, 'pid': pid}
        patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
        patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
        patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
        patch_batch['original_img_shape'] = patient_batch['original_img_shape']

        converter = ConvertSegToBoundingBoxCoordinates(
            self.cf.dim, get_rois_from_seg_flag=False,
            class_specific_seg_flag=self.cf.class_specific_seg_flag)
        patch_batch = converter(**patch_batch)
        out_batch = patch_batch

    self.patient_ix += 1
    if self.patient_ix == len(self.dataset_pids):
        self.patient_ix = 0

    # zero out the channels excluded at test time (in place on the returned data array).
    out_batch['data'][:, self.cf.drop_channels_test] = 0.
    return out_batch
def generate_train_batch(self):
    """Sample a random training batch, pre-cropped to ``self.cf.pre_crop_size``.

    Patients are drawn uniformly with replacement from ``self._data``. Data and seg are
    loaded and transposed with ``axes=(3, 1, 2, 0)``. Only the first entry of the seg's
    leading axis is kept, so the seg is purely spatial.

    In 2D, one slice per patient is drawn:
      * Slices listed in ``patient['fg_slices']`` jointly get probability ``self.p_fg``.
      * If every slice is a foreground slice, the slice is drawn uniformly among them.
      * With ``cf.n_3D_context`` set, the neighbouring slices are stacked into the channel axis.

    Data and seg are zero-padded up to ``pre_crop_size`` and cropped to it along every
    larger dim. With probability ``p_fg`` (and only if the seg has foreground), the crop is
    centred near a random voxel of a random ROI.

    Returns:
        dict with 'data', 'seg' (uint8, with a channel axis), 'pid' (list) and
        'class_target'.
    """
    batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []
    patients = list(self._data.values())
    # Class-balanced sampling (dutils.get_class_balanced_patients with cf.head_classes /
    # cf.batch_sample_slack) is deliberately disabled because it was problematic with
    # class 20. Patients are drawn uniformly with replacement instead.
    batch_ixs = np.random.choice(len(patients), self.batch_size)

    for b in batch_ixs:
        patient = patients[b]
        # NOTE(review): axes=(3, 1, 2, 0) yields (c, y, x, z) if the arrays are stored as
        # (z, y, x, c); an earlier comment claimed a (c, z, y, x) layout -- confirm.
        data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(3, 1, 2, 0))
        # seg is transposed like data, so it also has a leading axis. Only its first entry
        # is kept, making seg purely spatial. The seg indexing below (argwhere anchor coords,
        # np.take(..., axis=ii)) relies on seg axis ii matching data axis ii + 1.
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(3, 1, 2, 0))[0]
        batch_pids.append(patient['pid'])
        batch_targets.append(patient['class_target'])

        if self.cf.dim == 2:
            n_slices = data.shape[3]
            fg_slices = patient['fg_slices']
            # draw random slice from patient while oversampling slices containing foreground objects with p_fg.
            if len(fg_slices) > 0:
                n_bg = n_slices - len(fg_slices)
                if n_bg > 0:
                    fg_prob = self.p_fg / len(fg_slices)
                    bg_prob = (1 - self.p_fg) / n_bg
                else:
                    # Every slice contains foreground, so there is no background slice to
                    # receive mass. Sample uniformly among the foreground slices.
                    fg_prob, bg_prob = 1. / len(fg_slices), 0.
                slices_prob = [fg_prob if ix in fg_slices else bg_prob for ix in range(n_slices)]
                slice_id = np.random.choice(n_slices, p=slices_prob)
            else:
                slice_id = np.random.choice(n_slices)

            # if set to not None, add neighbouring slices to each selected slice in channel dimension.
            if self.cf.n_3D_context is not None:
                n_ctx = self.cf.n_3D_context
                padded_data = dutils.pad_nd_image(
                    data[0], [(data.shape[-1] + (n_ctx * 2))], mode='constant')
                padded_slice_id = slice_id + n_ctx
                data = np.concatenate(
                    [padded_data[..., ii][np.newaxis]
                     for ii in range(padded_slice_id - n_ctx, padded_slice_id + n_ctx + 1)],
                    axis=0)
            else:
                data = data[..., slice_id]
            seg = seg[..., slice_id]

        # pad data if smaller than pre_crop_size.
        if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
            new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
            data = dutils.pad_nd_image(data, new_shape, mode='constant')
            seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

        # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
        crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
        if len(crop_dims) > 0:
            fg_prob_sample = np.random.rand(1)
            # with p_fg: sample random pixel from random ROI and shift center by random value.
            if fg_prob_sample < self.p_fg and np.sum(seg) > 0:
                seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1))
                roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0]
                assert seg[tuple(roi_anchor_pixel)] > 0
                # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                # distance to the desired ROI < patch_size /2.
                # (here final patch size to account for center_crop after data augmentation).
                sample_seg_center = {}
                for ii in crop_dims:
                    reach = self.cf.patch_size[ii] // 2 - self.crop_margin[ii]
                    low = np.max((self.cf.pre_crop_size[ii] // 2, roi_anchor_pixel[ii] - reach))
                    high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii] // 2,
                                   roi_anchor_pixel[ii] + reach))
                    # happens if lesion on the edge of the image. dont care about roi anymore,
                    # just make sure pre-crop is inside image.
                    if low >= high:
                        half = data.shape[ii + 1] // 2
                        low = half - (half - self.cf.pre_crop_size[ii] // 2)
                        high = half + (half - self.cf.pre_crop_size[ii] // 2)
                    sample_seg_center[ii] = np.random.randint(low=low, high=high)
            else:
                # not guaranteed to be empty. probability of emptiness depends on the data.
                sample_seg_center = {
                    ii: np.random.randint(low=self.cf.pre_crop_size[ii] // 2,
                                          high=data.shape[ii + 1] - self.cf.pre_crop_size[ii] // 2)
                    for ii in crop_dims}

            for ii in crop_dims:
                min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)

        batch_data.append(data)
        # re-add the leading axis removed from seg above, giving (1, *spatial) per sample.
        batch_segs.append(seg[np.newaxis])

    data = np.array(batch_data)
    seg = np.array(batch_segs).astype(np.uint8)
    class_target = np.array(batch_targets)
    return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target}
def generate_train_batch(self):
    """Build the 3D inference batch for the next patient in ``self.dataset_pids``.

    Steps:
      * Load ``patient['data']`` and ``patient['seg']``, transposed with
        ``axes=(1, 2, 0)``; data additionally gets a leading channel axis.
      * Pad both up to ``self.patch_size`` where they are smaller.
      * Compute whole-patient targets with ``ConvertSegToBoundingBoxCoordinates``.
      * If the volume is larger than ``self.patch_size``, tile it with
        ``dutils.get_patch_crop_coords_stride(..., self.testing_patch_stride)`` into patches
        stacked along the batch axis.

    ``self.patient_ix`` advances cyclically.

    Raises:
        NotImplementedError: if ``self.cf.dim != 3``; this variant only builds 3D batches.
    """
    if self.cf.dim != 3:
        # No 2D batch (batch_2D) is ever built in this variant; fail loudly and early
        # instead of with a NameError after loading the patient.
        raise NotImplementedError(
            'this generate_train_batch only supports cf.dim == 3, got {}'.format(self.cf.dim))

    pid = self.dataset_pids[self.patient_ix]
    patient = self._data[pid]
    data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]  # (c, y, x, z)
    seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
    batch_class_targets = np.array([patient['class_target']])

    # pad data if smaller than patch_size seen during training.
    if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.patch_size)]):
        new_shape = [data.shape[0]] + [np.max([data.shape[dim + 1], ps])
                                       for dim, ps in enumerate(self.patch_size)]
        data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
        # seg has no channel axis: pad it to the spatial part of new_shape only.
        seg = dutils.pad_nd_image(seg, new_shape[1:])

    # whole-patient 3D targets for evaluation.
    out_data = data[np.newaxis]
    out_seg = seg[np.newaxis, np.newaxis]
    batch_3D = {'data': out_data, 'seg': out_seg, 'class_target': batch_class_targets, 'pid': pid}
    # NOTE(review): class_specific_seg_flag is hard-coded False here but read from
    # self.cf.class_specific_seg_flag for the patch batch below -- confirm this is intended.
    converter = ConvertSegToBoundingBoxCoordinates(
        dim=3, get_rois_from_seg_flag=False, class_specific_seg_flag=False)
    batch_3D = converter(**batch_3D)
    batch_3D.update({'patient_bb_target': batch_3D['bb_target'],
                     'patient_roi_labels': batch_3D['roi_labels'],
                     'original_img_shape': out_data.shape})
    out_batch = patient_batch = batch_3D

    # crop patient-volume to patches of patch_size (stride self.testing_patch_stride),
    # stacked up in the batch dimension.
    if np.any([data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
        patch_crop_coords_list = dutils.get_patch_crop_coords_stride(
            data[0], self.patch_size, self.testing_patch_stride)
        new_img_batch, new_seg_batch = [], []
        for c in patch_crop_coords_list:
            new_seg_batch.append(seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]])
            new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])

        data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
        seg = np.array(new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
        # one class-target entry per patch
        batch_class_targets = np.repeat(batch_class_targets, len(patch_crop_coords_list), axis=0)

        patch_batch = {'data': data, 'seg': seg, 'class_target': batch_class_targets, 'pid': pid}
        patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
        # whole-patient ground truth carried along with the patches
        patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
        patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
        patch_batch['original_img_shape'] = patient_batch['original_img_shape']
        converter = ConvertSegToBoundingBoxCoordinates(
            self.cf.dim, get_rois_from_seg_flag=False,
            class_specific_seg_flag=self.cf.class_specific_seg_flag)
        out_batch = converter(**patch_batch)

    self.patient_ix += 1
    if self.patient_ix == len(self.dataset_pids):
        self.patient_ix = 0

    # Collapse the patient-level ROI labels to a single foreground class: if the first
    # label of the (single) patient is positive, its entry is overwritten with [1].
    # NOTE(review): relies on the converter's 'roi_labels' layout -- confirm that assigning
    # [1] covers every ROI of the patient (it broadcasts if the entry is an array row).
    if out_batch['patient_roi_labels'][0][0] > 0:
        out_batch['patient_roi_labels'][0] = [1]
    return out_batch
def generate_train_batch(self):
    """Sample a random training batch, pre-cropped to ``self.cf.pre_crop_size``.

    Patient selection:
      * If ``self.cf.head_classes > 2``, patients are drawn class-balanced via
        ``dutils.get_class_balanced_patients``.
      * Otherwise they are drawn uniformly with replacement.

    Per patient:
      * The volume is transposed with ``axes=(1, 2, 0)``; data gets a leading channel axis.
      * Data and seg are zero-padded up to ``pre_crop_size`` where smaller.
      * They are then cropped to ``pre_crop_size`` along every larger dimension.
      * With probability ``self.p_fg``, and only if the seg contains foreground, the crop is
        centred near a random voxel of a random ROI.

    Returns:
        dict with 'data' (stacked images), 'seg' (uint8, with a channel axis),
        'pid' (list) and 'class_target' (stacked per-patient targets).
    """
    all_targets = [entry['class_target'] for entry in self._data.values()]
    if self.cf.head_classes > 2:
        # samples patients towards equilibrium of foreground classes on a roi-level
        # (after randomly sampling the ratio "batch_sample_slack").
        picks = dutils.get_class_balanced_patients(
            all_targets, self.batch_size, self.cf.head_classes - 1,
            slack_factor=self.cf.batch_sample_slack)
    else:
        picks = np.random.choice(len(all_targets), self.batch_size)
    records = list(self._data.values())
    pcs = self.cf.pre_crop_size

    imgs, masks, pids, targets = [], [], [], []
    for pick in picks:
        rec = records[pick]
        img = np.transpose(np.load(rec['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]  # (c, y, x, z)
        mask = np.transpose(np.load(rec['seg'], mmap_mode='r'), axes=(1, 2, 0))
        pids.append(rec['pid'])
        targets.append(rec['class_target'])

        # zero-pad every spatial dim that is smaller than pre_crop_size.
        if np.any([img.shape[d + 1] < size for d, size in enumerate(pcs)]):
            target_shape = [np.max([img.shape[d + 1], size]) for d, size in enumerate(pcs)]
            img = dutils.pad_nd_image(img, target_shape, mode='constant')
            mask = dutils.pad_nd_image(mask, target_shape, mode='constant')

        # pre-crop every spatial dim that is larger than pre_crop_size.
        dims_to_crop = [d for d, size in enumerate(pcs) if img.shape[d + 1] > size]
        if dims_to_crop:
            if np.random.rand(1) < self.p_fg and np.sum(mask) > 0:
                # Foreground crop: anchor on a random voxel of a random ROI. The anchor is kept
                # within patch_size // 2 - crop_margin of the centre, so it survives the later
                # centre crop to patch_size after augmentation.
                chosen_roi = np.random.choice(np.unique(mask)[1:], 1)
                roi_voxels = np.argwhere(mask == chosen_roi)
                anchor = roi_voxels[np.random.choice(roi_voxels.shape[0], 1)][0]
                assert mask[tuple(anchor)] > 0
                centers = {}
                for d in dims_to_crop:
                    reach = self.cf.patch_size[d] // 2 - self.crop_margin[d]
                    lo = np.max((pcs[d] // 2, anchor[d] - reach))
                    hi = np.min((img.shape[d + 1] - pcs[d] // 2, anchor[d] + reach))
                    if lo >= hi:
                        # ROI too close to the image border: ignore it and only keep the
                        # pre-crop inside the image.
                        half = img.shape[d + 1] // 2
                        lo = half - (half - pcs[d] // 2)
                        hi = half + (half - pcs[d] // 2)
                    centers[d] = np.random.randint(low=lo, high=hi)
            else:
                # random crop; may or may not contain foreground, depending on the data.
                centers = {d: np.random.randint(low=pcs[d] // 2, high=img.shape[d + 1] - pcs[d] // 2)
                           for d in dims_to_crop}

            for d in dims_to_crop:
                start = int(centers[d] - pcs[d] // 2)
                stop = int(centers[d] + pcs[d] // 2)
                img = np.take(img, indices=range(start, stop), axis=d + 1)
                mask = np.take(mask, indices=range(start, stop), axis=d)

        imgs.append(img)
        masks.append(mask[np.newaxis])

    return {
        'data': np.array(imgs),
        'seg': np.array(masks).astype(np.uint8),
        'pid': pids,
        'class_target': np.array(targets),
    }
def generate_train_batch(self, pid=None):
    """Build the evaluation batch for a single patient.

    Args:
        pid: patient id to load; if None, ``self.dataset_pids[self.patient_ix]`` is used.

    Steps:
      * Load the patient's data (float16, restricted to channels ``self.chans``) and its
        ``self.gt_prefix + 'seg'`` array (uint8).
      * Pad both up to ``self.patch_size`` where they are smaller.
      * Build a 3D batch (when ``cf.dim == 3`` or ``cf.merge_2D_to_3D_preds``) and/or a 2D
        batch with the z axis used as batch axis (``cf.dim == 2``). Each is passed through
        ``ConvertSegToBoundingBoxCoordinates``.
      * If the spatial shape still exceeds ``self.patch_size``, tile the volume via
        ``dutils.get_patch_crop_coords`` and stack the patches along the batch axis. The
        whole-patient entries are kept under ``patient_*`` keys.

    NOTE: ``self.patient_ix`` is advanced (cyclically) on every call, even when an explicit
    ``pid`` is passed.

    Returns:
        dict batch (see the keys assigned below).
    """
    if pid is None:
        pid = self.dataset_pids[self.patient_ix]
    patient = self._data[pid]

    # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
    data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
    seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]

    data_shp_raw = data.shape
    # plotting background: kept as an array only when its channel is NOT among the selected channels
    plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
    data = data[self.chans]
    # number of discarded channels with a lower index than plot_bg_chan (used below to re-index it)
    discarded_chans = len(
        [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
    spatial_shp = data[0].shape # spatial dims need to be in order x,y,z
    assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"

    if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
        new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
        data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape.
        seg = dutils.pad_nd_image(seg, new_shape)
        if plot_bg is not None:
            plot_bg = dutils.pad_nd_image(plot_bg, new_shape)

    if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
        # adds the batch dim here bc won't go through MTaugmenter
        out_data = data[np.newaxis]
        out_seg = seg[np.newaxis]
        if plot_bg is not None:
            out_plot_bg = plot_bg[np.newaxis]
        # data and seg shape: (1,c,x,y,z), where c=1 for seg
        batch_3D = {'data': out_data, 'seg': out_seg}
        for o in self.cf.roi_items:
            batch_3D[o] = np.array([patient[self.gt_prefix+o]])
        converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
        batch_3D = converter(**batch_3D)
        batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
        for o in self.cf.roi_items:
            batch_3D["patient_" + o] = batch_3D[o]

    if self.cf.dim == 2:
        out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32') # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
        out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8') # (c,y,x,z) to (b=z,c,x,y)
        batch_2D = {'data': out_data, 'seg': out_seg}
        # every slice inherits the patient-level roi items
        for o in self.cf.roi_items:
            batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
        converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
        batch_2D = converter(**batch_2D)
        if plot_bg is not None:
            out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')
        # patient-level targets: taken from the 3D batch when 2D predictions get merged to 3D
        if self.cf.merge_2D_to_3D_preds:
            batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape})
            for o in self.cf.roi_items:
                batch_2D["patient_" + o] = batch_3D[o]
        else:
            batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape})
            for o in self.cf.roi_items:
                batch_2D["patient_" + o] = batch_2D[o]

    out_batch = batch_3D if self.cf.dim == 3 else batch_2D
    out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})

    # If the plot background channel is among the kept channels, store its index shifted by the
    # number of discarded lower channels instead of an image.
    if self.cf.plot_bg_chan in self.chans and discarded_chans > 0: # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
        assert plot_bg is None
        plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
        out_plot_bg = plot_bg
    if plot_bg is not None:
        out_batch['plot_bg'] = out_plot_bg

    # eventual tiling into patches
    spatial_shp = out_batch["data"].shape[2:]
    if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
        patient_batch = out_batch
        # NOTE(review): unconditional print in a data iterator
        print("patientiterator produced patched batch!")
        patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
        new_img_batch, new_seg_batch = [], []

        for c in patch_crop_coords_list:
            new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
            seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
            new_seg_batch.append(seg_patch)
        # NOTE(review): shps is collected but never used
        shps = []
        for arr in new_img_batch:
            shps.append(arr.shape)

        data = np.array(new_img_batch) # (patches, c, x, y, z)
        seg = np.array(new_seg_batch)
        if self.cf.dim == 2:
            # all patches have z dimension 1 (slices). discard dimension
            data = data[..., 0]
            seg = seg[..., 0]
        patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                       'pid': np.array([patient['pid']] * data.shape[0])}
        for o in self.cf.roi_items:
            patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
        #patient-wise (orig) batch info for putting the patches back together after prediction
        for o in self.cf.roi_items:
            patch_batch["patient_"+o] = patient_batch["patient_"+o]
            if self.cf.dim == 2:
                # this could also be named "unpatched_2d_roi_items"
                patch_batch["patient_" + o + "_2d"] = patient_batch[o]
        patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
        patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
        if self.cf.dim == 2:
            patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
        patch_batch['patient_data'] = patient_batch['data']
        patch_batch['patient_seg'] = patient_batch['seg']
        patch_batch['original_img_shape'] = patient_batch['original_img_shape']
        if plot_bg is not None:
            patch_batch['patient_plot_bg'] = patient_batch['plot_bg']

        converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                       class_specific_seg=self.cf.class_specific_seg)

        patch_batch = converter(**patch_batch)
        out_batch = patch_batch

    # advance to the next patient, wrapping around after the last one
    self.patient_ix += 1
    if self.patient_ix == len(self.dataset_pids):
        self.patient_ix = 0

    return out_batch
def generate_train_batch(self):
    """Sample one class-balanced training batch, pre-cropped to ``self.cf.pre_crop_size``.

    For every pid returned by ``self.get_batch_pids()``:

    * 2D slice choice. One slice is picked.
      - A foreground pick happens with probability ``self.p_fg``, or always once every
        target's empty count in the batch has reached ``self.empty_samples_max_ratio``.
        Targets are then tried in order of increasing ``batch_roi_counts``, and a slice
        holding an ROI of the first target the patient has is chosen.
      - Otherwise a slice outside ``patient['fg_slices']`` is preferred.
      - If nothing is eligible, any slice is used.
    * Data and seg are padded up to ``pre_crop_size`` and then cropped to it along every
      larger dimension.
    * Crop centre. For foreground samples it is drawn near a voxel of an ROI of the
      least-represented target. In 3D, foreground is chosen if ANY target's empty ratio is
      reached, or with probability ``p_fg``.
    * Per-target statistics:
      - ``roi_counts``: ROIs whose ``self.balance_target`` value equals the target. In 3D
        this counts the whole patient; in 2D it counts the cropped slice.
      - ``empty_counts``: samples without any such ROI.

    Returns:
        dict with 'data', 'seg', 'pid', 'roi_counts', 'empty_counts' and one stacked array
        per name in ``self.cf.roi_items``.
    """
    # everything done in here is per batch
    # print statements in here get confusing due to multithreading
    batch_pids = self.get_batch_pids()

    batch_data, batch_segs = [], []
    batch_roi_items = {name: [] for name in self.cf.roi_items}
    # record roi count and empty count of classes in batch
    # empty count for no presence of resp. class in whole sample (empty slices in 2D/patients in 3D)
    batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
    batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')

    for b in range(len(batch_pids)):
        patient = self._data[batch_pids[b]]

        data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
        seg = np.load(patient['seg'], mmap_mode='r').astype('uint8')

        (c, y, x, z) = data.shape
        if self.cf.dim == 2:
            elig_slices, choose_fg = [], False
            if len(patient['fg_slices']) > 0:
                if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                        np.random.rand(1) <= self.p_fg:
                    # fg is to be picked
                    for tix in np.argsort(batch_roi_counts):
                        # pick slices of patient that have roi of sought-for target
                        # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                        elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                            patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                            self.unique_ts[tix]) > 0]
                        if len(elig_slices) > 0:
                            choose_fg = True
                            break
                else:
                    # pick bg
                    elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])

            if len(elig_slices) > 0:
                sl_pick_ix = np.random.choice(elig_slices, size=None)
            else:
                sl_pick_ix = np.random.choice(z, size=None)
            data = data[..., sl_pick_ix]
            seg = seg[..., sl_pick_ix]

        spatial_shp = data[0].shape
        assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
        if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
            new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
            data = dutils.pad_nd_image(data, (len(data), *new_shape))
            seg = dutils.pad_nd_image(seg, new_shape)

        if self.cf.dim == 3:
            # The 3D roi/empty counts below come from the patient's full target list, so their
            # sanity check must look at the seg BEFORE pre-cropping. A random crop may
            # legitimately miss every ROI of a patient that does have ROIs.
            patient_has_fg = np.any(seg)

        # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center,
        # if possible, to that pixel, so that img still contains ROI after pre-cropping
        dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
        if np.any(dim_cropflags):
            # sample pixel from random ROI and shift center, if possible, to that pixel
            if self.cf.dim == 3:
                choose_fg = np.any(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                            np.random.rand(1) <= self.p_fg
            if choose_fg and np.any(seg):
                available_roi_ids = np.unique(seg)[1:]
                for tix in np.argsort(batch_roi_counts):
                    elig_roi_ids = available_roi_ids[
                        patient[self.balance_target][available_roi_ids - 1] == self.unique_ts[tix]]
                    if len(elig_roi_ids) > 0:
                        seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                        break
                roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                assert seg[tuple(roi_anchor_pixel)] > 0

                # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and
                # distance to the selected ROI < patch_size /2
                def get_cropped_centercoords(dim):
                    low = np.max((self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] - (
                        self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                    high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] + (
                        self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                    if low >= high:  # happens if lesion on the edge of the image.
                        low = self.cf.pre_crop_size[dim] // 2
                        high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2

                    assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(
                        dim, spatial_shp, patient['pid'], low, high)
                    return np.random.randint(low=low, high=high)
            else:
                # sample crop center regardless of ROIs, not guaranteed to be empty
                def get_cropped_centercoords(dim):
                    return np.random.randint(low=self.cf.pre_crop_size[dim] // 2,
                                             high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2)

            sample_seg_center = {}
            for dim in np.where(dim_cropflags)[0]:
                sample_seg_center[dim] = get_cropped_centercoords(dim)
                min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2)
                max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2)
                data = np.take(data, indices=range(min_, max_), axis=dim + 1)  # +1 for channeldim
                seg = np.take(seg, indices=range(min_, max_), axis=dim)

        batch_data.append(data)
        batch_segs.append(seg[np.newaxis])

        for o in batch_roi_items:  # after loop, holds every entry of every batchpatient per observable
            batch_roi_items[o].append(patient[o])

        if self.cf.dim == 3:
            # patient-level counts (independent of the crop)
            for tix in range(len(self.unique_ts)):
                non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                batch_roi_counts[tix] += non_zero
                batch_empty_counts[tix] += int(non_zero == 0)
                # todo remove assert when checked -- compared against the uncropped seg
                if not patient_has_fg:
                    assert non_zero == 0
        elif self.cf.dim == 2:
            # counts of the ROIs actually present in the cropped slice
            for tix in range(len(self.unique_ts)):
                non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg > 0]) - 1] == self.unique_ts[tix])
                batch_roi_counts[tix] += non_zero
                batch_empty_counts[tix] += int(non_zero == 0)
                # todo remove assert when checked
                if not np.any(seg):
                    assert non_zero == 0

    batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
             'pid': batch_pids, 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
    for key, val in batch_roi_items.items():  # extend batch dic by entries of observables dic
        batch[key] = np.array(val)

    return batch