def cuda2np(image):
    """Convert a torch tensor (on CPU or CUDA) to a NumPy array.

    Args:
        image: a ``torch.Tensor`` or ``numpy.ndarray`` (original note: layout
            d, h, w; nothing in this function depends on the layout).

    Returns:
        A ``numpy.ndarray`` for tensor input. An ndarray is returned unchanged.
        Any other type is logged as an error and returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        # detach() first so tensors that require grad can be converted on both
        # devices; cpu() is a no-op for tensors already on the host.
        image = image.detach().cpu().numpy()
    elif not isinstance(image, np.ndarray):
        logger.error('image should be torch.Tensor or numpy.ndarray')
    return image
def _cc_sample_angle(bounds):
    """Return ``bounds[0]`` for a degenerate range, else a uniform draw from it."""
    if bounds[0] == bounds[1]:
        return bounds[0]
    return np.random.uniform(bounds[0], bounds[1])


def _cc_transform_coords(coords, dim, do_elastic_deform, alpha, sigma,
                         do_rotation, angle_x, angle_y, angle_z, do_scale,
                         scale, p_el_per_sample, p_scale_per_sample,
                         p_rot_per_sample):
    """Randomly elastic-deform, rotate and scale a coordinate mesh.

    Each transform is applied with its own per-sample probability, in the
    order elastic -> rotation -> scale.

    Returns:
        (coords, applied): the transformed mesh and the names of the applied
        transforms ('elastic', 'rotation', 'scale').
    """
    applied = []
    if np.random.uniform() < p_el_per_sample and do_elastic_deform:
        a = np.random.uniform(alpha[0], alpha[1])
        s = np.random.uniform(sigma[0], sigma[1])
        coords = elastic_deform_coordinates(coords, a, s)
        applied.append('elastic')
    if np.random.uniform() < p_rot_per_sample and do_rotation:
        a_x = _cc_sample_angle(angle_x)
        if dim == 3:
            a_y = _cc_sample_angle(angle_y)
            a_z = _cc_sample_angle(angle_z)
            coords = rotate_coords_3d(coords, a_x, a_y, a_z)
        else:
            coords = rotate_coords_2d(coords, a_x)
        applied.append('rotation')
    if np.random.uniform() < p_scale_per_sample and do_scale:
        # Half of the time (when the range allows it) draw a factor below 1,
        # otherwise a factor of at least 1.
        if np.random.random() < 0.5 and scale[0] < 1:
            sc = np.random.uniform(scale[0], 1)
        else:
            sc = np.random.uniform(max(scale[0], 1), scale[1])
        coords = scale_coords(coords, sc)
        applied.append('scale')
    return coords, applied


def _cc_target_mask(config_task, patch_type, seg_volume):
    """Boolean mask of the voxels a 'fore' / 'small' patch must contain.

    'fore': any label > 0. 'small': label 1 for Task05_Prostate, otherwise the
    label ``config_task.num_class - 1``.
    """
    if patch_type == 'fore':
        return seg_volume > 0
    if config_task.task == 'Task05_Prostate':
        return seg_volume == 1
    return seg_volume == config_task.num_class - 1


def _cc_patch_is_valid(config_task, patch_type, data_patch, seg_patch):
    """Accept a patch with non-zero image content (plus the target label for 'fore'/'small')."""
    has_image = np.any(data_patch != 0)
    if patch_type in ('fore', 'small'):
        # cancer could be very very small, so the check is approximate:
        # any target voxel anywhere in the patch is enough.
        return bool(np.any(_cc_target_mask(config_task, patch_type, seg_patch)) and has_image)
    return bool(has_image)


def _cc_patch_center(data_shape, patch_size, patch_type, cand_point_coord,
                     border, random_crop):
    """Choose the patch center (one int per spatial axis) for one sample.

    With ``random_crop``:
      * 'fore'/'small': drawn within ``patch_size/2 - 1`` of the candidate voxel
        ``cand_point_coord`` and at least ``border - 1`` from the volume edges.
      * 'any': drawn uniformly from ``[border, size - border)``.
    Without ``random_crop`` (or for an empty range, which is logged as an
    error) the volume center is used.
    """
    ctr_list = []
    for d, patch_len in enumerate(patch_size):
        extent = data_shape[d + 2]
        ctr = int(np.round(extent / 2.))  # center crop / fallback
        if random_crop and patch_type in ('fore', 'small'):
            half = patch_len / 2 - 1
            low = int(max(border[d] - 1, cand_point_coord[d] - half))
            upper = int(min(cand_point_coord[d] + half, extent - (border[d] - 1)))
            if low == upper:
                ctr = low
            elif low < upper:
                ctr = int(np.random.randint(low, upper))
            else:
                logger.error(
                    '(low:{} should be <= upper:{}). patch_type:{}, patch_center_dist_from_border:{}, '
                    'cand_point_coord:{}, data.shape:{}, ctr_list:{}; falling back to the volume center'
                    .format(low, upper, str(patch_type), str(border),
                            str(cand_point_coord), str(data_shape), str(ctr_list)))
        elif random_crop and patch_type == 'any':
            low, upper = border[d], extent - border[d]
            if low == upper:
                ctr = int(low)
            elif low < upper:
                ctr = int(np.random.randint(low, upper))
            else:
                logger.error(
                    'low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, '
                    'data.shape:{}, ctr_list:{}; falling back to the volume center'
                    .format(str(patch_type), str(border), str(data_shape), str(ctr_list)))
        ctr_list.append(ctr)
    return ctr_list


def cc_augment(config_task,
               data,
               seg,
               patch_type,
               patch_size,
               patch_center_dist_from_border=30,
               do_elastic_deform=True,
               alpha=(0., 1000.),
               sigma=(10., 13.),
               do_rotation=True,
               angle_x=(0, 2 * np.pi),
               angle_y=(0, 2 * np.pi),
               angle_z=(0, 2 * np.pi),
               do_scale=True,
               scale=(0.75, 1.25),
               border_mode_data='constant',
               border_cval_data=0,
               order_data=3,
               border_mode_seg='constant',
               border_cval_seg=0,
               order_seg=0,
               random_crop=True,
               p_el_per_sample=1,
               p_scale_per_sample=1,
               p_rot_per_sample=1,
               tag=''):
    """Spatially augment and crop one patch per sample.

    For each sample, attempts repeat until the patch has non-zero image content
    and, for ``patch_type`` 'fore' / 'small', contains the target label. Each
    attempt builds a fresh zero-centered mesh, so retries do not stack
    transforms or center offsets. During the first 10 attempts, an attempt
    that applied at least one transform resamples with ``interpolate_img``.
    Any other attempt is a plain crop and records no augmentation.

    Args:
        data: image array [n, c, d, h, w].
        seg: label array [n, c, d, h, w] or None (then ``patch_type`` is
            treated as 'any'). Only channel 0 is used to locate targets.
        patch_type: 'fore', 'small' or 'any'.
        patch_size: spatial patch size.
        patch_center_dist_from_border: scalar or per-axis margin; should be
            no more than 1/2 patch size.
        tag: free text included in retry log messages.

    Returns:
        (data_result, seg_result, augs): float32 arrays [n, c, *patch_size]
        (``seg_result`` is None without ``seg``), and the names of the
        transforms used for the accepted patches. Returns (None, None, None)
        when a 'fore'/'small' sample has no voxel of the requested label.

    NOTE(review): if a valid patch can never be produced (e.g. center crop
    that misses the target), the retry loop does not terminate.
    """
    patch_size = list(patch_size)
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros([seg.shape[0], seg.shape[1]] + patch_size, dtype=np.float32)
    data_result = np.zeros([data.shape[0], data.shape[1]] + patch_size, dtype=np.float32)
    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]
    if seg is None:
        patch_type = 'any'

    augs = list()
    for sample_id in range(data.shape[0]):
        n = 0
        while True:
            coords = create_zero_centered_coordinate_mesh(patch_size)
            coords, sample_augs = _cc_transform_coords(
                coords, dim, do_elastic_deform, alpha, sigma, do_rotation,
                angle_x, angle_y, angle_z, do_scale, scale, p_el_per_sample,
                p_scale_per_sample, p_rot_per_sample)

            # Pick a random target voxel; the patch center is drawn around it.
            cand_point_coord = None
            if patch_type in ('fore', 'small'):
                if seg.shape[1] > 1:
                    logger.error('TBD for seg with multiple channels')
                lab_coords = np.where(_cc_target_mask(config_task, patch_type, seg[sample_id, 0, ...]))
                if len(lab_coords[0]) == 0:
                    # Requested label does not exist in this sample.
                    return None, None, None
                idx = np.random.choice(len(lab_coords[0]))
                cand_point_coord = [axis_coords[idx] for axis_coords in lab_coords]

            ctr_list = _cc_patch_center(data.shape, patch_size, patch_type,
                                        cand_point_coord, patch_center_dist_from_border,
                                        random_crop)

            if n < 10 and sample_augs:
                # Move the transformed mesh to the chosen center and resample.
                for d in range(dim):
                    coords[d] += ctr_list[d]
                for channel_id in range(data.shape[1]):
                    data_result[sample_id, channel_id] = interpolate_img(
                        data[sample_id, channel_id], coords, order_data,
                        border_mode_data, cval=border_cval_data)
                if seg is not None:
                    for channel_id in range(seg.shape[1]):
                        seg_result[sample_id, channel_id] = interpolate_img(
                            seg[sample_id, channel_id], coords, order_seg,
                            border_mode_seg, cval=border_cval_seg, is_seg=True)
            else:
                # No transform this attempt (or too many attempts): plain crop.
                sample_augs = []
                if random_crop:
                    data_result[sample_id] = np.asarray([
                        utils.extract_roi_from_volume(
                            data[sample_id, channel_id], ctr_list, patch_size, fill="zero")
                        for channel_id in range(data.shape[1])
                    ])
                    if seg is not None:
                        seg_result[sample_id] = np.asarray([
                            utils.extract_roi_from_volume(
                                seg[sample_id, channel_id], ctr_list, patch_size, fill="zero")
                            for channel_id in range(seg.shape[1])
                        ])
                else:
                    s = None if seg is None else seg[sample_id:sample_id + 1]
                    d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s)
                    data_result[sample_id] = d
                    if seg is not None:
                        seg_result[sample_id] = s

            # Judge only this sample's patch, not the whole batch.
            seg_patch = None if seg_result is None else seg_result[sample_id]
            valid = _cc_patch_is_valid(config_task, patch_type, data_result[sample_id], seg_patch)
            n += 1
            if n > 5:
                cand_value = None
                if cand_point_coord is not None:
                    cand_value = seg[sample_id, 0][tuple(cand_point_coord)]
                seg_counts = None
                if seg_patch is not None:
                    seg_counts = np.unique(seg_patch, return_counts=True)
                logger.info(
                    'tag:{}, patch_type: {}; handler: {}; times: {}; cand point:{}; cand point seg value:{}; '
                    'ctr_list:{}; data.shape:{}; np.unique(seg_result):{}; np.sum(data_result):{}'
                    .format(tag, patch_type, int(valid), n, str(cand_point_coord), cand_value,
                            str(ctr_list), str(data.shape), seg_counts,
                            np.sum(data_result[sample_id])))
            if valid:
                augs.extend(sample_augs)
                break
    return data_result, seg_result, augs
def gen_batch(self, batch_size, patch_size):
    """Draw ``batch_size`` patches from ``self.trainQueue`` and stack them.

    Each queue item is a dict with keys 'small', 'fore' and 'any'. A chosen
    entry must provide 'image', 'label', 'weight' and 'augs'. One uniform draw
    per batch slot decides which entry is used:

    * slots ``i <= ceil(batch_size / 3)``: 'small' if the draw is below
      ``config_task.small_prob`` and it exists, else 'fore'. If neither exists,
      the item is dropped and the next one is fetched.
    * later slots: 'small' as above, else 'fore' if ``1 - draw <
      config_task.fore_prob`` and it exists, else 'any'.

    An empty queue is polled once per second until it is refilled.

    Returns:
        (images, labels, weights, augs): arrays of shape (n, mod, d, h, w),
        (n, d, h, w), (n, d, h, w), and the list of per-slot 'augs' entries.
    """
    depth, height, width = patch_size[0], patch_size[1], patch_size[2]
    images = np.zeros([batch_size, self.config_task.num_modality, depth, height, width])  # n,mod,d,h,w
    labels = np.zeros([batch_size, depth, height, width])  # n,d,h,w
    weights = np.zeros([batch_size, depth, height, width])  # n,d,h,w
    aug_records = []
    # nn_unet3d: at least 1/3 samples in a batch contain at least one foreground class
    forced_fore_limit = math.ceil(batch_size / 3)

    for i in range(batch_size):
        draw = np.random.uniform()
        accepted = False
        while not accepted:
            if self.trainQueue.qsize() == 0:
                logger.info(
                    '{} self.trainQueue size = {}, filling....(start time:{})'
                    .format(self.task, self.trainQueue.qsize(), tinies.datestr()))
                seconds_waited = 0
                while self.trainQueue.qsize() == 0:
                    time.sleep(1)
                    seconds_waited += 1
                if seconds_waited > 0:
                    logger.info('{} time to fill self.trainQueue: {}'.format(
                        self.task, seconds_waited))
            patches = self.trainQueue.get()
            want_small = draw < self.config_task.small_prob

            if i > forced_fore_limit:
                # Free slot: always accepts something from this queue item.
                if want_small and patches['small'] is not None:
                    chosen = patches['small']
                elif 1 - draw < self.config_task.fore_prob and patches['fore'] is not None:
                    chosen = patches['fore']
                else:
                    chosen = patches['any']
                accepted = True
            elif want_small and patches['small'] is not None:
                chosen = patches['small']
                accepted = True
            elif patches['fore'] is not None:
                chosen = patches['fore']
                accepted = True
            else:
                # Foreground slot but this item has none: discard and retry.
                logger.warn('handler={}'.format(0))
                logger.info('handler is 0, going back')

        images[i, ...] = chosen['image']
        labels[i, ...] = chosen['label']
        weights[i, ...] = chosen['weight']
        aug_records.append(chosen['augs'])

    return (images, labels, weights, aug_records)
# Build the per-task configuration for every task given on the command line.
for task in args.tasks:
    config.config_tasks[task] = config.set_config_task(args.trainMode, task, config.base_dir)

# A non-empty output tag becomes a '_'-prefixed suffix of the result directory name.
if args.out_tag:
    args.out_tag = '_' + args.out_tag

#### Prepare datasets
# fold_splits.json is read from the parent of the current working directory.
# dict: {'Task02_Heart'/...}{'fold index'}{'train'/'val'}
with open(os.path.join(os.path.dirname(os.getcwd()), 'fold_splits.json'), mode='r') as f:
    tasks_archive = json.load(f)

# seed NumPy's global RNG
np.random.seed(1993)

#### prep train
if args.trainMode == "independent":
    # Only the adapter-based train modes are handled by this script.
    logger.error('trainMode should be one of parallel_adapter, shared_adapter')
else:
    ### model settings
    config.patch_size = [128, 128, 128]
    config.patch_weights = tinies.calPatchWeights(config.patch_size)

    # Output layout: <out_dir>/res_<model>_<trainMode><out_tag>/<task1>_<task2>...
    config.out_dir = os.path.join(
        config.out_dir,
        'res_{}_{}{}'.format(args.model, args.trainMode, args.out_tag),
        '_'.join(args.tasks),
    )
    tinies.sureDir(config.out_dir)

    config.eval_out_dir = os.path.join(config.out_dir, "eval_out")
    tinies.newdir(config.eval_out_dir)

    config.log_dir = os.path.join(config.out_dir, 'train_log')
    config.writer = MySummaryWriter(log_dir=config.log_dir)  # this will create log_dir
    # 'b' reuse log_dir and backup log.log
    logger.set_logger_dir(os.path.join(config.log_dir, 'logger'), action="b")
    logger.info(
        '--------------------------------Training for {}: {}--------------------------------'
        .format(args.trainMode, '_'.join(args.tasks)))
def _sliding_window_centers(length, patch_len):
    """Distinct patch centers along one axis for sliding-window inference.

    Centers start at ``patch_len / 2`` and step by half a patch (50% overlap).
    Centers beyond ``length - patch_len / 2`` are clamped to it, which can make
    trailing centers coincide. Duplicates are dropped (order kept) so no window
    is predicted twice.
    """
    half = int(patch_len / 2)
    clamped = (min(c, length - half) for c in range(half, length + half, half))
    return list(dict.fromkeys(clamped))


def _predict_patch_batch(model, patches, batch_size, data_channel,
                         model_patch_size, patch_weights):
    """Run ``model`` on up to ``batch_size`` patches; return weighted scores.

    A short batch is zero-padded to ``batch_size`` for the forward pass, and
    only the first ``len(patches)`` outputs are returned.

    Returns:
        numpy array (len(patches), d, h, w, num_class): model outputs moved to
        channels-last and multiplied voxel-wise by ``patch_weights``.
    """
    n_real = len(patches)
    padding = [np.zeros([data_channel] + model_patch_size)] * (batch_size - n_real)
    batch = torch.from_numpy(np.asarray(list(patches) + padding, np.float32)).float().cuda()
    with torch.no_grad():  # inference only; avoid building an autograd graph
        if config.trainMode == "universal":
            prob, _, _ = model(batch)
        else:
            prob = model(batch)
    prob = prob.detach().permute([0, 2, 3, 4, 1])  # n,d,h,w,c
    prob = prob[:n_real] * patch_weights.unsqueeze(-1)
    return prob.cpu().numpy()


def batch_segmentation(config_task, temp_imgs, model):
    """Segment one case by sliding-window patch inference.

    Args:
        config_task: per-task config; ``patch_size`` (task patch size) and
            ``num_class`` are read.
        temp_imgs: image array unpacked as (channels, D, H, W).
        model: network called on a (batch_size, channels, d, h, w) CUDA tensor
            whose output is laid out (n, c, d, h, w). In
            ``config.trainMode == "universal"`` it returns a 3-tuple whose first
            item is that output.

    Returns:
        ``np.uint8`` label volume (argmax over classes). Padding added by
        ``tinies.pad2gePatch`` is cropped off again.

    Raises:
        ValueError: if ``config.unifyPatch`` is not ``'resize'``.
    """
    if config.unifyPatch != 'resize':
        raise ValueError('{}: not yet implemented!!'.format(config.unifyPatch))

    # model patch size. if args.trainMode='independent', equal to config_task.patch_size; else, not equal.
    model_patch_size = list(config.patch_size)
    batch_size = config.batch_size
    num_class = config_task.num_class
    patch_weights = torch.from_numpy(config.patch_weights).float().cuda()

    data_channel, original_D, original_H, original_W = temp_imgs.shape
    # For some cases (e.g. Task04_Hippocampus) the image is smaller than the
    # task patch size, so pad up to it; the padding is cropped off at the end.
    # Remember to apply the same process in get_train_dataflow().
    temp_imgs, pad_size = tinies.pad2gePatch(temp_imgs, config_task.patch_size, data_channel)
    data_channel, D, H, W = temp_imgs.shape
    oldShape = [D, H, W]

    # Shared/universal models use one patch size for all tasks: scale the image
    # by model_patch_size / task_patch_size, and scale the prediction back later.
    resize_to_model = config.trainMode in ["shared", "universal"]
    if resize_to_model:
        # Exact integer floor; a float product can land just below an integer
        # and be truncated one voxel short.
        newShape = [int(old * m // t) for old, m, t in
                    zip(oldShape, model_patch_size, config_task.patch_size)]
        temp_imgs = np.asarray([
            skimage.transform.resize(channel, output_shape=tuple(newShape),
                                     order=3, mode='constant')  # bi-cubic
            for channel in temp_imgs
        ])
        data_channel, D, H, W = temp_imgs.shape

    temp_prob1 = np.zeros([D, H, W, num_class])
    centers_D = _sliding_window_centers(D, model_patch_size[0])
    centers_H = _sliding_window_centers(H, model_patch_size[1])
    centers_W = _sliding_window_centers(W, model_patch_size[2])
    centers = [[c_d, c_h, c_w] for c_w in centers_W for c_h in centers_H for c_d in centers_D]

    st_time = time.time()
    for start in range(0, len(centers), batch_size):
        chunk = centers[start:start + batch_size]
        patches = [
            np.asanyarray([extract_roi_from_volume(temp_imgs[chn], center, model_patch_size,
                                                   fill="zero")
                           for chn in range(data_channel)], np.float32)  # [mod,d,h,w]
            for center in chunk
        ]
        probs = _predict_patch_batch(model, patches, batch_size, data_channel,
                                     model_patch_size, patch_weights)
        for sub_prob, center in zip(probs, chunk):
            for c in range(num_class):
                temp_prob1[..., c] = set_roi_to_volume(temp_prob1[..., c], center, sub_prob[..., c])
    logger.info('patch eval for-loop time elapsed:{}'.format(tinies.timer(st_time, time.time())))

    temp_pred1 = np.argmax(temp_prob1, axis=-1)
    if resize_to_model:
        # float32: resize() (preserve_range=False) rescales integer inputs such
        # as np.uint8 to [0, 1], which would wipe out the labels.
        # anti_aliasing=False: otherwise skimage Gaussian-smooths float input
        # when downsampling, blending neighbouring label values.
        temp_pred1 = skimage.transform.resize(temp_pred1.astype(np.float32),
                                              output_shape=tuple(oldShape), order=0,
                                              mode='constant', anti_aliasing=False)  # nearest
    temp_pred1 = np.asarray(temp_pred1, dtype=np.uint8)

    # Undo the padding from pad2gePatch to recover the original shape.
    if np.any(pad_size):
        temp_pred1 = temp_pred1[pad_size[0]:(original_D + pad_size[0]),
                                pad_size[1]:(original_H + pad_size[1]),
                                pad_size[2]:(original_W + pad_size[2])]
    return temp_pred1