def tensor2figure(
        self,
        image,
        selected_z_list,
        colorslist=['#000000', '#00FF00', '#0000FF', '#FF0000', '#FFFF00'],
        is_label=False,
        fig_title=''):
    '''
    Render the selected z-slices of a volume into a figure.

    image: d, h, w
    colorslist: hex colors.
    steps: see train.tb_images().
        step1: selected_z_list = self.chooseSlices()
        step2: fig = self.tensor2figure(...,selected_z_list,...)
        step3: self.add_figure(fig)

    Returns the figure produced by imagelist2figure(), or None when building
    it raised (the error message is logged instead of propagated).
    '''
    image = cuda2np(image)
    # Bind fig up front: previously a failure inside imagelist2figure() left
    # fig unbound and the return statement raised UnboundLocalError.
    fig = None
    try:
        fig = imagelist2figure(image,
                               selected_z_list,
                               colorslist=colorslist,
                               is_label=is_label,
                               fig_title=fig_title)
    except Exception as e:
        logger.info('{}'.format(str(e)))
        # logger.info('image.shape:{}, selected_z_list:{}'.format(str(image.shape), str(selected_z_list)))
    return fig
def __init__(self, nb_tasks, inChans, outChans, kernel_size=3, stride=1, padding=1, second=0):
    """Build one conv unit followed by one norm op per task.

    Which sub-modules are created depends on ``stride`` and on the global
    ``config.trainMode`` / ``config.module`` settings:

    - stride != 1: a plain ``nn.Conv3d`` stored in ``self.conv``.
    - stride == 1 and trainMode != 'universal' (independent / shared):
      a plain ``nn.Conv3d`` stored in ``self.conv``.
    - stride == 1 and trainMode == 'universal':
        * 'series_adapter': ``self.conv`` plus one ``conv1x1(outChans)``
          per task in ``self.adapOps``.
        * 'parallel_adapter': ``self.conv`` plus one
          ``conv1x1(inChans, outChans)`` per task in ``self.adapOps``.
        * 'separable_adapter': one ``dwise(inChans)`` per task in
          ``self.adapOps`` and a single ``pwise(inChans, outChans)`` in
          ``self.pwise``; no ``self.conv`` is created.
        * any other module name: nothing is created here besides ``self.op``.

    In every case ``self.op`` holds one ``norm_act(outChans, only='norm')``
    per task.

    NOTE(review): ``self.conv`` is absent for 'separable_adapter' and for an
    unknown module in universal mode; presumably forward() branches on the
    same settings — confirm.

    Args:
        nb_tasks: number of tasks; one adapter / norm op is built per task.
        inChans: input channel count.
        outChans: output channel count.
        kernel_size, stride, padding: forwarded to ``nn.Conv3d``.
        second: accepted but not referenced in this constructor.
    """
    super(conv_unit, self).__init__()
    self.stride = stride

    if self.stride != 1:
        self.conv = nn.Conv3d(inChans, outChans, kernel_size=kernel_size, stride=stride, padding=padding)  # padding != 0 for stride != 2 if doing padding=SAME.
    elif self.stride == 1:
        if config.trainMode != 'universal':  # independent, shared
            self.conv = nn.Conv3d(inChans, outChans, kernel_size=kernel_size, stride=stride, padding=padding)  # padding != 0 for stride != 2 if doing padding=SAME.
        else:
            if config.module in ['series_adapter', 'parallel_adapter']:
                self.conv = nn.Conv3d(inChans, outChans, kernel_size=kernel_size, stride=stride, padding=padding)  # padding != 0 for stride != 2 if doing padding=SAME.
                if config.module == 'series_adapter':
                    self.adapOps = nn.ModuleList([conv1x1(outChans) for i in range(nb_tasks)])  # based on https://github.com/srebuffi/residual_adapters/
                elif config.module == 'parallel_adapter':
                    self.adapOps = nn.ModuleList([conv1x1(inChans, outChans) for i in range(nb_tasks)])
                else:
                    pass
            elif config.module == 'separable_adapter':
                logger.info('using module of :{}'.format(config.module))
                # per-task ops on the input channels, then one shared op mapping inChans -> outChans
                self.adapOps = nn.ModuleList([dwise(inChans) for i in range(nb_tasks)])
                self.pwise = pwise(inChans, outChans)
            else:
                pass

    # one norm op per task, built regardless of stride / mode
    self.op = nn.ModuleList([norm_act(outChans, only='norm') for i in range(nb_tasks)])
def eval(args, tasks_archive, model, eval_epoch, iterations):
    """Evaluate ``model`` on each task's validation split and record the scores.

    For every task in ``args.tasks`` this calls ``evaluate.evaluate`` on
    ``tasks_archive[task]['fold<args.fold>']['val']``. Each returned metric is
    written to TensorBoard (``config.writer``) and appended as one CSV row to
    ``<config.eval_out_dir>/<args.trainMode>_eval_res.csv``.

    The name shadows the ``eval`` builtin; it is kept because callers use it.

    Args:
        args: parsed CLI arguments (reads ``tasks``, ``fold``, ``trainMode``).
        tasks_archive: nested mapping task -> 'fold<k>' -> 'val' -> data.
        model: network to evaluate; switched to ``eval()`` mode here.
        eval_epoch: epoch number passed to evaluate() and written to the CSV.
        iterations: global step used for the TensorBoard scalars.
    """
    tasks = args.tasks  # list
    model.eval()
    for task_idx, task in enumerate(tasks):
        config.task_idx = task_idx  # needed for u2net3d().
        config_task = config.config_tasks[task]
        st_time = time.time()
        # evaluating. tensorboard visualization of eval embedded.
        dices = evaluate.evaluate(config_task,
                                  tasks_archive[task]['fold' + str(args.fold)]['val'],
                                  model,
                                  epoch_num=eval_epoch,
                                  outdir=config.eval_out_dir)
        csv_path = os.path.join(config.eval_out_dir,
                                '{}_eval_res.csv'.format(args.trainMode))
        # The context manager closes the handle after each task; previously a
        # new handle was opened per task and never closed. newline='' is the
        # csv-module-recommended mode for files passed to csv.writer.
        with open(csv_path, mode='a+', newline='') as fo:
            wo = csv.writer(fo, delimiter=',')
            for k, v in dices.items():
                config.writer.add_scalar('data/dices/{}_{}'.format(task, k), v, iterations)
                wo.writerow([
                    args.trainMode, task, eval_epoch, config.step_per_epoch,
                    k, v, tinies.datestr()
                ])
        logger.info('Eval time elapsed:{}'.format(
            tinies.timer(st_time, time.time())))
def post_processing(config_task, pred_raw, temp_weight=None, ID=''):
    """Clean up a predicted label volume class by class.

    For every foreground class ``i`` in ``1 .. num_class - 1``:

    - If ``i`` is the highest class and its label name contains 'cancer' or
      'tumour', its voxels are copied through unchanged. Several lesions can
      legitimately coexist, e.g. multiple liver tumours.
    - Otherwise the binary mask of class ``i`` is morphologically closed and
      then reduced with ``get_largest_one_component``. This applies to
      Task02_Heart, Task03_Liver, Task07_Pancreas and Task09_Spleen; for the
      spleen, the upper half of the last axis is zeroed first. All other
      tasks use ``get_largest_two_component``. If that helper raises, the
      closed mask is kept and the failure is logged.

    Args:
        config_task: task config; reads ``num_class``, ``labels`` (keyed by
            str(class index)) and ``task``.
        pred_raw: 3-D integer label volume.
        temp_weight: optional array multiplied into ``pred_raw`` before
            processing; defaults to all ones.
        ID: case identifier forwarded to the component helpers.

    Returns:
        ``np.uint8`` array with the same shape as ``pred_raw``.
    """
    struct = ndimage.generate_binary_structure(3, 2)
    wt_threshold = None
    if temp_weight is None:
        temp_weight = np.ones_like(pred_raw)
    pred_raw = pred_raw * temp_weight
    # uint8 (by Chao): keeps out_label a multi-class array instead of 0/1.
    out_label = np.zeros_like(pred_raw, dtype=np.uint8)
    for i in range(1, config_task.num_class):
        pred_tmp = np.zeros_like(pred_raw)
        pred_tmp[pred_raw == i] = 1
        # labels[str(i)] is only looked up for the highest class (short-circuit).
        if i == config_task.num_class - 1 and any(
                x in str(config_task.labels[str(i)]).lower()
                for x in ('cancer', 'tumour')):
            # don't reduce the highest-level class (e.g. cancer) to its largest
            # components: some cases, like liver cancer, have multiple tumors.
            out_label[pred_raw == i] = i
        else:
            pred_tmp = ndimage.binary_closing(pred_tmp, structure=struct)
            try:
                if config_task.task in [
                        'Task02_Heart', 'Task03_Liver', 'Task07_Pancreas',
                        'Task09_Spleen'
                ]:
                    if config_task.task in ['Task09_Spleen']:
                        pred_tmp[..., :, int(pred_raw.shape[-1] / 2)::] = 0
                    pred_tmp = get_largest_one_component(
                        pred_tmp, wt_threshold, ID + '_label' + str(i))
                else:
                    pred_tmp = get_largest_two_component(
                        pred_tmp, wt_threshold, ID + '_label' + str(i))
            except Exception as e:
                # Best effort: keep the closed mask, but record why it failed
                # instead of silently swallowing every error (incl. Ctrl-C).
                logger.info(' class:{}, np.uniques(pred_raw):{}'.format(
                    i, str(np.unique(pred_raw, return_counts=True))))
                logger.info(' component extraction failed for {} label{}: {}'.format(
                    ID, i, str(e)))
            out_label[pred_tmp == 1] = i
    return out_label
def segment_one_image(config_task, data, model, ID=''):
    """Run inference on one case and map the labels back to its original frame.

    Args:
        config_task: task configuration forwarded to the helpers.
        data: case dict. Reads 'image' (d, h, w, mod), 'weight' (channel 0 of
            its last axis is used), 'original_shape' (shape before cropping
            and resampling), 'bbox' (entries [0] and [1] are passed on as the
            ROI bounds) and 'im_path'.
        model: network forwarded to batch_segmentation().
        ID: case identifier forwarded to post_processing().

    Returns:
        np.int16 label volume of shape data['original_shape'].
    """
    def _since(start):
        return tinies.timer(start, time.time())

    image = data['image']  # d,h,w,mod
    weight = data['weight'][:, :, :, 0]  # d,h,w
    target_shape = data['original_shape']
    bbox = data['bbox']
    source_path = data['im_path']

    # add and strip a batch axis, then move modalities first: mod, d, h, w
    net_input = np.transpose(image[np.newaxis, ...][0], [3, 0, 1, 2])

    tic = time.time()
    raw_labels = batch_segmentation(config_task, net_input, model)
    logger.info('batch_segmentation time elapsed:{}'.format(_since(tic)))

    if not config.post_processing:
        labels = np.asarray(raw_labels, np.int16)
    else:
        tic = time.time()
        cleaned = post_processing(config_task, raw_labels, weight, ID)
        logger.info('post_processing time elapsed:{}'.format(_since(tic)))
        labels = np.asarray(cleaned, np.int16)

    tic = time.time()
    canvas = np.zeros(target_shape, np.int16)  # d,h,w
    restored = set_ND_volume_roi_with_bounding_box_range(
        config_task, canvas, bbox[0], bbox[1], labels,
        sitk.sitkNearestNeighbor, source_path)
    logger.info('set_ND_volume_roi time elapsed:{}'.format(_since(tic)))
    return restored
def cc_augment(config_task,
               data,
               seg,
               patch_type,
               patch_size,
               patch_center_dist_from_border=30,
               do_elastic_deform=True,
               alpha=(0., 1000.),
               sigma=(10., 13.),
               do_rotation=True,
               angle_x=(0, 2 * np.pi),
               angle_y=(0, 2 * np.pi),
               angle_z=(0, 2 * np.pi),
               do_scale=True,
               scale=(0.75, 1.25),
               border_mode_data='constant',
               border_cval_data=0,
               order_data=3,
               border_mode_seg='constant',
               border_cval_seg=0,
               order_seg=0,
               random_crop=True,
               p_el_per_sample=1,
               p_scale_per_sample=1,
               p_rot_per_sample=1,
               tag=''):
    """Spatially augment and crop one patch per sample of ``data`` / ``seg``.

    For each sample, random elastic deformation, rotation and scaling are
    applied to a zero-centred coordinate mesh of ``patch_size``; each fires
    with its ``p_*_per_sample`` probability and flag. A patch centre is then
    chosen per axis:

    - 'fore': near a random voxel with seg > 0;
    - 'small': near a random voxel of label 1 (Task05_Prostate) or of label
      ``config_task.num_class - 1`` (other tasks);
    - 'any': uniformly within ``patch_center_dist_from_border`` of the borders.

    With ``random_crop=False`` the volume centre is used. If any coordinate
    augmentation fired and fewer than 10 attempts were made, the patch is
    sampled with ``interpolate_img`` along the augmented coordinates.
    Otherwise a plain ROI crop (``utils.extract_roi_from_volume``) or centre
    crop (``center_crop_aug``) is taken and ``augs`` is reset to an empty
    list. Sampling repeats until the result contains the requested label (for
    'fore'/'small') and some non-zero image data.

    Args:
        config_task: task config; reads ``task`` and ``num_class``.
        data: array laid out [n, c, d, h, w] (per the inline comment).
        seg: array laid out [n, c, d, h, w], or None (None forces 'any').
        patch_type: 'fore', 'small' or 'any'.
        patch_size: list of spatial patch sizes. It must be a list because it
            is concatenated with a list to build the output shapes.
        patch_center_dist_from_border: scalar or per-axis margin for the
            patch centre.
        do_elastic_deform, alpha, sigma: elastic deformation switch and the
            ranges the magnitude / scale are drawn from.
        do_rotation, angle_x, angle_y, angle_z: rotation switch and angle
            ranges; a range with equal ends is used as a fixed angle.
        do_scale, scale: scaling switch and zoom range.
        border_mode_*, border_cval_*, order_*: forwarded to interpolate_img.
        random_crop: pick a random centre (True) or the volume centre (False).
        p_el_per_sample, p_scale_per_sample, p_rot_per_sample: per-sample
            probabilities of each augmentation.
        tag: string included in the retry log message.

    Returns:
        (data_result, seg_result, augs). data_result and seg_result are
        float32 arrays of shape [n, c] + patch_size; seg_result is None when
        seg is None. augs lists the augmentation names appended ('elastic',
        'rotation', 'scale'). If a 'fore'/'small' patch is requested but that
        label is absent, all three are set to None.

    NOTE(review): known issues, left unchanged here:
      * With more than one sample, a None result from one sample makes the
        next sample index None and fail.
      * ``coords`` is built once per sample, outside the retry loop, so
        retries re-augment and re-shift already-transformed coordinates.
      * The patch checks test the whole ``seg_result`` / ``data_result``
        batch, not only ``sample_id``.
      * The 'low > upper' error log adds a list to an array.
      * The ``n > 5`` log indexes ``seg`` with ``cand_point_coord``, which is
        unbound for 'any' and fails when seg is None.
    """
    # patch_center_dist_from_border should be no more than 1/2 patch size. otherwise code not available.
    # data: [n,c,d,h,w]
    # seg: [n,c,d,h,w]
    dim = len(patch_size)
    seg_result = None
    if seg is not None:
        seg_result = np.zeros([seg.shape[0], seg.shape[1]] + patch_size,
                              dtype=np.float32)

    data_result = np.zeros([data.shape[0], data.shape[1]] + patch_size,
                           dtype=np.float32)

    # broadcast a scalar margin to every spatial axis
    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    ## for-loop for dim[0]
    augs = list()  # shared by all samples and retries
    for sample_id in range(data.shape[0]):
        coords = create_zero_centered_coordinate_mesh(patch_size)
        # now find a nice center location and extract patch
        if seg is None:
            patch_type = 'any'
        handler = 0  # becomes 1 once an acceptable patch was produced
        n = 0  # attempt counter for this sample
        while handler == 0:
            # augmentation
            modified_coords = False
            if np.random.uniform() < p_el_per_sample and do_elastic_deform:
                a = np.random.uniform(alpha[0], alpha[1])
                s = np.random.uniform(sigma[0], sigma[1])
                coords = elastic_deform_coordinates(coords, a, s)
                modified_coords = True
                augs.append('elastic')

            if np.random.uniform() < p_rot_per_sample and do_rotation:
                if angle_x[0] == angle_x[1]:
                    a_x = angle_x[0]
                else:
                    a_x = np.random.uniform(angle_x[0], angle_x[1])
                if dim == 3:
                    if angle_y[0] == angle_y[1]:
                        a_y = angle_y[0]
                    else:
                        a_y = np.random.uniform(angle_y[0], angle_y[1])
                    if angle_z[0] == angle_z[1]:
                        a_z = angle_z[0]
                    else:
                        a_z = np.random.uniform(angle_z[0], angle_z[1])
                    coords = rotate_coords_3d(coords, a_x, a_y, a_z)
                else:
                    coords = rotate_coords_2d(coords, a_x)
                modified_coords = True
                augs.append('rotation')

            if np.random.uniform() < p_scale_per_sample and do_scale:
                # half of the time zoom in (sc < 1) when the range allows it
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
                coords = scale_coords(coords, sc)
                modified_coords = True
                augs.append('scale')

            # find candidate area for center, the area is cand_point_coord +/- patch_size
            if patch_type in ['fore', 'small'] and seg is not None:
                if seg.shape[1] > 1:
                    logger.error('TBD for seg with multiple channels')
                if patch_type == 'fore':
                    lab_coords = np.where(
                        seg[sample_id, 0, ...] > 0)  # lab_coords: tuple
                elif patch_type == 'small':
                    if config_task.task == 'Task05_Prostate':
                        lab_coords = np.where(seg[sample_id, 0, ...] == 1)
                    else:
                        lab_coords = np.where(
                            seg[sample_id, 0, ...] == config_task.num_class - 1)
                if len(lab_coords[0]) > 0:  # 0 means no such label exists
                    idx = np.random.choice(len(lab_coords[0]))
                    cand_point_coord = [
                        coords[idx] for coords in lab_coords
                    ]  # coords for one random point from 'fore' ground
                else:
                    cand_point_coord = None

            if patch_type in ['fore', 'small'] and cand_point_coord is None:
                # requested label absent: give up and return None results
                ctr_list = None
                handler = 1
                data_result = None
                seg_result = None
                augs = None
            else:
                ctr_list = list()  # coords of the patch center
                for d in range(dim):
                    if random_crop:
                        if patch_type in ['fore', 'small'] and seg is not None:
                            # centre within half a patch of the candidate voxel, clipped to the border margin
                            low = max(
                                patch_center_dist_from_border[d] - 1,
                                cand_point_coord[d] - (patch_size[d] / 2 - 1))
                            low = int(low)
                            upper = min(
                                cand_point_coord[d] + (patch_size[d] / 2 - 1),
                                data.shape[d + 2] -
                                (patch_center_dist_from_border[d] - 1)
                            )  # +/- patch_size[d] is better but computation costly
                            upper = int(upper)
                            if low == upper:
                                ctr = int(low)
                            elif low < upper:
                                ctr = int(np.random.randint(low, upper))
                                # if n > 1:
                                #     logger.info('n:{}; [low,upper]:{}, ctr:{}'.format(n, str([low, upper]), ctr))
                            else:
                                # NOTE(review): ctr is not reassigned on this path
                                logger.error(
                                    '(low:{} should be <= upper:{}). patch_type:{}, patch_center_dist_from_border:{}, cand_point_coord:{}, cand point seg value:{}, data.shape:{}, ctr_list:{}'
                                    .format(low, upper, str(patch_type),
                                            str(patch_center_dist_from_border),
                                            str(cand_point_coord),
                                            seg[sample_id, 0] + cand_point_coord,
                                            str(data.shape), str(ctr_list)))
                        elif patch_type == 'any':
                            if patch_center_dist_from_border[d] == data.shape[
                                    d + 2] - patch_center_dist_from_border[d]:
                                ctr = int(patch_center_dist_from_border[d])
                            elif patch_center_dist_from_border[d] < data.shape[
                                    d + 2] - patch_center_dist_from_border[d]:
                                ctr = int(
                                    np.random.randint(
                                        patch_center_dist_from_border[d],
                                        data.shape[d + 2] -
                                        patch_center_dist_from_border[d]))
                            else:
                                logger.error(
                                    'low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, data.shape:{}, ctr_list:{}'
                                    .format(str(patch_type),
                                            str(patch_center_dist_from_border),
                                            str(data.shape), str(ctr_list)))
                    else:  # center crop
                        ctr = int(np.round(data.shape[d + 2] / 2.))
                    ctr_list.append(ctr)

                # extracting patch
                if n < 10 and modified_coords:
                    # shift the augmented, zero-centred mesh to the chosen centre and resample
                    for d in range(dim):
                        coords[d] += ctr_list[d]
                    for channel_id in range(data.shape[1]):
                        data_result[sample_id, channel_id] = interpolate_img(
                            data[sample_id, channel_id],
                            coords,
                            order_data,
                            border_mode_data,
                            cval=border_cval_data)
                    if seg is not None:
                        for channel_id in range(seg.shape[1]):
                            seg_result[sample_id, channel_id] = interpolate_img(
                                seg[sample_id, channel_id],
                                coords,
                                order_seg,
                                border_mode_seg,
                                cval=border_cval_seg,
                                is_seg=True)
                else:
                    # no augmentation (or too many attempts): plain crop, so no augmentations are reported
                    augs = list()
                    if seg is None:
                        s = None
                    else:
                        s = seg[sample_id:sample_id + 1]
                    if random_crop:
                        # margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                        # d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
                        d_tmps = list()
                        for channel_id in range(data.shape[1]):
                            d_tmp = utils.extract_roi_from_volume(
                                data[sample_id, channel_id],
                                ctr_list,
                                patch_size,
                                fill="zero")
                            d_tmps.append(d_tmp)
                        d = np.asarray(d_tmps)
                        if seg is not None:
                            s_tmps = list()
                            for channel_id in range(seg.shape[1]):
                                s_tmp = utils.extract_roi_from_volume(
                                    seg[sample_id, channel_id],
                                    ctr_list,
                                    patch_size,
                                    fill="zero")
                                s_tmps.append(s_tmp)
                            s = np.asarray(s_tmps)
                    else:
                        d, s = center_crop_aug(data[sample_id:sample_id + 1],
                                               patch_size, s)
                    # data_result[sample_id] = d[0]
                    data_result[sample_id] = d
                    if seg is not None:
                        # seg_result[sample_id] = s[0]
                        seg_result[sample_id] = s

                ## check patch
                if patch_type in [
                        'fore'
                ]:  # cancer could be very very small. so use approximate method (i.e. use 'fore').
                    if np.any(seg_result > 0) and np.any(data_result != 0):
                        handler = 1
                    else:
                        handler = 0
                elif patch_type in ['small']:
                    if config_task.task == 'Task05_Prostate':
                        if np.any(seg_result == 1) and np.any(
                                data_result != 0):
                            handler = 1
                        else:
                            handler = 0
                    else:
                        if np.any(seg_result == config_task.num_class -
                                  1) and np.any(data_result != 0):
                            handler = 1
                        else:
                            handler = 0
                else:
                    if np.any(data_result != 0):
                        handler = 1
                    else:
                        handler = 0
            n += 1
            if n > 5:
                # NOTE(review): fails for patch_type 'any' / seg None (see docstring)
                logger.info(
                    'tag:{}, patch_type: {}; handler: {}; times: {}; cand point:{}; cand point seg value:{}; ctr_list:{}; data.shape:{}; np.unique(seg_result):{}; np.sum(data_result):{}'
                    .format(
                        tag, patch_type, handler, n, str(cand_point_coord),
                        seg[sample_id, 0, cand_point_coord[0],
                            cand_point_coord[1], cand_point_coord[2]],
                        str(ctr_list), str(data.shape),
                        np.unique(seg_result, return_counts=True),
                        np.sum(data_result)))
    return data_result, seg_result, augs
logger.error('trainMode should be one of parallel_adapter, shared_adapter') elif args.trainMode != "independent": ### model settings config.patch_size = [128,128,128] config.patch_weights = tinies.calPatchWeights(config.patch_size) config.out_dir = os.path.join(config.out_dir, 'res_{}_{}{}'.format(args.model, args.trainMode, args.out_tag), '_'.join(args.tasks)) tinies.sureDir(config.out_dir) config.eval_out_dir = os.path.join(config.out_dir, "eval_out") tinies.newdir(config.eval_out_dir) config.log_dir = os.path.join(config.out_dir, 'train_log') config.writer = MySummaryWriter(log_dir=config.log_dir) # this will create log_dir logger.set_logger_dir(os.path.join(config.log_dir, 'logger'), action="b") # 'b' reuse log_dir and backup log.log logger.info('--------------------------------Training for {}: {}--------------------------------'.format(args.trainMode, '_'.join(args.tasks))) # instantialize model inChans_list = [config.config_tasks[task].num_modality for task in args.tasks] # input num_modality num_class_list = [config.config_tasks[task].num_class for task in args.tasks] model = u2net3d.u2net3d(inChans_list=inChans_list, base_outChans=config.base_outChans, num_class_list=num_class_list) torch.manual_seed(1) model.apply(train.weights_init) # if transfer learning # Load checkpoint and initialize the networks with the weights of a pretrained network if args.ckp != '' and args.resume_ckp == '': logger.info('==> Transferring from checkpoint: {}, loading checkpoint.....'.format(args.ckp)) checkpoint = torch.load(args.ckp) model_old = checkpoint['model']
config.patch_size = config.patch_sizes[task] config.patch_weights = tinies.calPatchWeights(config.patch_size) config.batch_size = config.batch_sizes[task] config.num_pool_per_axis = config.nums_pool_per_axis[task] config.base_outChans = config.base_outChanss[task] config.val_epoch = config.val_epochs[task] config.out_dir = os.path.join(config.out_dir, 'res_{}_{}{}'.format(args.model, args.trainMode, args.out_tag), task) tinies.sureDir(config.out_dir) config.eval_out_dir = os.path.join(config.out_dir, "eval_out") tinies.newdir(config.eval_out_dir) config.log_dir = os.path.join(config.out_dir, 'train_log') config.writer = MySummaryWriter(log_dir=config.log_dir) # this will create log_dir logger.set_logger_dir(os.path.join(config.log_dir, 'logger'), action="b") # 'b' reuse log_dir and backup log.log logger.info('--------------------------------Training for {}: {}--------------------------------'.format(args.trainMode, task)) # ??print for each task? # instantialize model inChans_list = [config.config_tasks[task].num_modality] # input num_modality num_class_list = [config.config_tasks[task].num_class] model = u2net3d.u2net3d(inChans_list=inChans_list, base_outChans=config.base_outChans, num_class_list=num_class_list) torch.manual_seed(1) model.apply(train.weights_init) train.train(args, tasks_archive, model) elif args.trainMode != "independent": ### model settings config.patch_size = [128,128,128] config.patch_weights = tinies.calPatchWeights(config.patch_size) # config.batch_size = 2
def __init__(self, inChans_list=None, base_outChans=16, num_class_list=None):
    '''
    Build the multi-task 3-D U-Net: task-specific input/output layers around
    a shared encoder-decoder backbone.

    One or more tasks can be built at once, so per-task settings are lists.

    Args:
        inChans_list: number of input modalities for each task; one
            InputTransition is built per task. Defaults to [2].
        base_outChans: out channels of the input transitions, i.e. in channels
            of the first layer of the shared backbone; level i of the
            backbone uses base_outChans * 2**i channels.
        num_class_list: number of output classes for each task. Its length
            defines the number of tasks. Defaults to [4].

    The backbone depth is not an argument: it is
    ``max(config.num_pool_per_axis) + 1``. Per-level strides come from
    ``num_pool2stride_size(config.num_pool_per_axis)``.
    '''
    # None sentinels instead of list literals avoid shared mutable defaults;
    # the effective defaults are unchanged.
    if inChans_list is None:
        inChans_list = [2]
    if num_class_list is None:
        num_class_list = [4]
    logger.info('------- base_outChans is {}'.format(base_outChans))
    super(u2net3d, self).__init__()

    nb_tasks = len(num_class_list)

    # config.num_pool_per_axis is first defined in train_xxxx.py or main.py
    self.depth = max(config.num_pool_per_axis) + 1
    stride_sizes = num_pool2stride_size(config.num_pool_per_axis)

    self.in_tr_list = nn.ModuleList([
        InputTransition(inChans_list[j], base_outChans)
        for j in range(nb_tasks)
    ])  # task-specific input layers

    outChans_list = list()
    self.down_blocks = nn.ModuleList()  # register modules from regular python list.
    self.down_samps = nn.ModuleList()
    self.down_pads = list()  # used to pad as padding='same' in tensorflow

    inChans = base_outChans
    for i in range(self.depth):
        outChans = base_outChans * (2**i)
        outChans_list.append(outChans)
        self.down_blocks.append(
            DownBlock(nb_tasks, inChans, outChans, kernel_size=3, stride=1, padding=1))

        if i != self.depth - 1:
            # stride for each axis could be 1 or 2, depending on tasks.
            # to emulate tensorflow's padding='SAME', compute and store the pad
            # amounts here and pad manually in forward().
            # F.pad order: last dim first, (left, right) per dim, e.g. w,w,h,h,d,d.
            # pad 1 on the right end if s=2, else 1 on both ends (s=1).
            pads = list()
            for j in stride_sizes[i][::-1]:
                if j == 2:
                    pads.extend([0, 1])
                elif j == 1:
                    pads.extend([1, 1])
            self.down_pads.append(pads)
            self.down_samps.append(
                DownSample(nb_tasks, outChans, outChans * 2, kernel_size=3,
                           stride=tuple(stride_sizes[i]), padding=0))
            inChans = outChans * 2
        else:
            inChans = outChans

    self.up_samps = nn.ModuleList([None] * (self.depth - 1))
    self.up_blocks = nn.ModuleList([None] * (self.depth - 1))
    # filled only when config.deep_supervision, for levels 0 < i < 3:
    # 0 entries if self.depth == 2, 1 if self.depth == 3, 2 if self.depth >= 4
    self.dSupers = nn.ModuleList()
    for i in range(self.depth - 2, -1, -1):
        self.up_samps[i] = UnetUpsample(nb_tasks, inChans, outChans_list[i],
                                        up_stride=stride_sizes[i])
        self.up_blocks[i] = UpBlock(nb_tasks, outChans_list[i] * 2, outChans_list[i],
                                    kernel_size=3, stride=1, padding=1)

        if config.deep_supervision and 0 < i < 3:
            self.dSupers.append(
                nn.ModuleList([
                    DeepSupervision(outChans_list[i], num_class_list[j],
                                    up_stride=tuple(stride_sizes[i - 1]))
                    for j in range(nb_tasks)
                ]))

        inChans = outChans_list[i]

    self.out_tr_list = nn.ModuleList([
        OutputTransition(inChans, num_class_list[j]) for j in range(nb_tasks)
    ])
def gen_batch(self, batch_size, patch_size):
    """Assemble one training batch from patch dicts taken off ``self.trainQueue``.

    Polls once per second while the queue is empty. Each queue item is a dict
    with 'small', 'fore' and 'any' entries; an entry is either None or a patch
    dict with 'image', 'label', 'weight' and 'augs'.

    A random ``p = np.random.uniform()`` is drawn once per sample:

    - samples with ``i <= ceil(batch_size / 3)`` must be foreground: 'small'
      if ``p < small_prob`` and available, else 'fore'. If neither is
      available, the queue item is discarded and another one is fetched.
    - later samples: 'small' if ``p < small_prob`` and available, else 'fore'
      if ``1 - p < fore_prob`` and available, else 'any'.

    NOTE(review): ``<=`` forces ceil(batch_size / 3) + 1 samples (capped at
    batch_size) to be foreground; the "at least 1/3" comment suggests ``<``
    may have been intended — confirm before changing.

    Args:
        batch_size: number of samples in the batch.
        patch_size: spatial patch size [d, h, w].

    Returns:
        (batchImg, batchLabel, batchWeight, batchAugs): float64 arrays shaped
        (n, mod, d, h, w), (n, d, h, w), (n, d, h, w), plus the list of each
        sample's 'augs' entry.
    """
    batchImg = np.zeros([
        batch_size, self.config_task.num_modality, patch_size[0],
        patch_size[1], patch_size[2]
    ])  # n,mod,d,h,w
    batchLabel = np.zeros(
        [batch_size, patch_size[0], patch_size[1], patch_size[2]])  # n,d,h,w
    batchWeight = np.zeros(
        [batch_size, patch_size[0], patch_size[1], patch_size[2]])  # n,d,h,w
    batchAugs = list()
    # nn_unet3d: the first samples of a batch must contain a foreground class.
    fore_bound = math.ceil(batch_size / 3)
    for i in range(batch_size):
        temp_prob = np.random.uniform()
        handler = 0  # 1 once a usable patch was picked for sample i
        while handler == 0:
            t_wait = 0
            if self.trainQueue.qsize() == 0:
                logger.info(
                    '{} self.trainQueue size = {}, filling....(start time:{})'
                    .format(self.task, self.trainQueue.qsize(), tinies.datestr()))
            while self.trainQueue.qsize() == 0:
                time.sleep(1)
                t_wait += 1
            if t_wait > 0:
                logger.info('{} time to fill self.trainQueue: {}'.format(
                    self.task, t_wait))

            patches = self.trainQueue.get()
            want_small = (temp_prob < self.config_task.small_prob
                          and patches['small'] is not None)
            if i <= fore_bound:
                if want_small:
                    patch = patches['small']
                    handler = 1
                elif patches['fore'] is not None:
                    patch = patches['fore']
                    handler = 1
                else:
                    # no foreground patch in this item: discard it and fetch another
                    logger.warn('handler={}'.format(handler))
            else:
                if want_small:
                    patch = patches['small']
                elif 1 - temp_prob < self.config_task.fore_prob and patches[
                        'fore'] is not None:
                    patch = patches['fore']
                else:
                    patch = patches['any']
                handler = 1
            if handler == 0:
                logger.info('handler is 0, going back')

        # fill in a batch
        batchImg[i, ...] = patch['image']
        batchLabel[i, ...] = patch['label']
        batchWeight[i, ...] = patch['weight']
        batchAugs.append(patch['augs'])

    return (batchImg, batchLabel, batchWeight, batchAugs)
def train(args, tasks_archive, model):
    """Train ``model`` on ``args.tasks``, cycling through the tasks step by step.

    Each step uses the task ``tasks[(iterations - 1) % len(tasks)]`` and the
    loss lovasz_softmax(softmax(output)) + focal_loss(output).

    After every epoch, the mean epoch loss is pushed into a window of the last
    ``config.trLoss_win`` epochs. Once the window is full and its mean has
    improved by less than 1e-4 since the previous epoch, the learning rate is
    halved.

    Checkpoints are written every ``config.save_epoch`` epochs and whenever a
    decay drops lr below 1e-8. Evaluation runs every ``config.val_epoch``
    epochs from ``config.start_val_epoch`` on. The epoch after lr falls below
    1e-8 runs a final evaluation and then stops.

    Args:
        args: parsed CLI arguments (reads resume_ckp, resume_epoch, tasks,
            fold, trainMode).
        tasks_archive: nested mapping task -> 'fold<k>' -> data, passed to the
            batch loaders and to eval().
        model: network to train; wrapped in nn.parallel.DataParallel here.
    """
    torch.backends.cudnn.benchmark = True

    checkpoint = None
    if args.resume_ckp != '':
        # log the path that is actually loaded (previously args.ckp was logged)
        logger.info('==> loading checkpoint: {}'.format(args.resume_ckp))
        checkpoint = torch.load(args.resume_ckp)

    model = nn.parallel.DataParallel(model)

    logger.info(' + model num_params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if config.use_gpu:
        model.cuda()  # required before optimizer?

    print(model)  # especially useful for debugging model structure.

    # lr / optimizer
    lr = config.base_lr
    if checkpoint is not None:
        optimizer = checkpoint['optimizer']
        # Keep the tracked lr in sync with the restored optimizer. Otherwise
        # the logged lr and the first decay step restart from config.base_lr.
        lr = optimizer.param_groups[0]['lr']
        # NOTE(review): the restored optimizer references the parameters it
        # was pickled with; confirm the caller passes the matching model.
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=config.weight_decay)

    # loss
    focal_loss = FocalLoss(gamma=2)

    # prep data
    tasks = args.tasks  # list
    tb_loaders = list()  # train batch loader
    for task in tasks:
        tb_loader = tb_load(task)
        tb_loader.enQueue(tasks_archive[task]['fold' + str(args.fold)],
                          config.patch_size)
        tb_loaders.append(tb_loader)

    # init train values
    if checkpoint is not None:
        trLoss_queue = checkpoint['trLoss_queue']
        last_trLoss_ma = checkpoint['last_trLoss_ma']
    else:
        # window of the last N epoch losses; its plain mean is the moving average
        trLoss_queue = deque(maxlen=config.trLoss_win)
        last_trLoss_ma = None  # the previous one.
    trLoss_queue_list = [deque(maxlen=config.trLoss_win) for _ in tasks]
    trLoss_ma_list = [None] * len(tasks)

    if args.resume_epoch > 0:
        start_epoch = args.resume_epoch + 1
        iterations = args.resume_epoch * config.step_per_epoch + 1
    else:
        start_epoch = 1
        iterations = 1
    logger.info('start epoch: {}'.format(start_epoch))

    ## run train
    for epoch in range(start_epoch, config.max_epoch + 1):
        logger.info(' ----- training epoch {} -----'.format(epoch))
        epoch_st_time = time.time()
        model.train()
        loss_epoch = 0.0
        loss_epoch_list = [0] * len(tasks)
        num_batch_processed = 0  # growing
        num_batch_processed_list = [0] * len(tasks)

        for step in tqdm(range(config.step_per_epoch),
                         desc='{}: epoch{}'.format(args.trainMode, epoch)):
            config.step = iterations
            config.task_idx = (iterations - 1) % len(tasks)
            config.task = tasks[config.task_idx]
            # tb show lr
            config.writer.add_scalar('data/lr', lr, iterations - 1)

            for idx in range(len(tasks)):
                tb_loaders[idx].check_process()
            (batchImg, batchLabel, batchWeight,
             batchAugs) = tb_loaders[config.task_idx].gen_batch(
                 config.batch_size, config.patch_size)

            # change all inputs to the same torch tensor type
            batchImg = torch.from_numpy(batchImg).float()
            batchLabel = torch.from_numpy(batchLabel).float()
            batchWeight = torch.from_numpy(batchWeight).float()

            if config.use_gpu:
                batchImg = batchImg.cuda()
                batchLabel = batchLabel.cuda()
                batchWeight = batchWeight.cuda()

            optimizer.zero_grad()

            if config.trainMode in ["universal"]:
                output, share_map, para_map = model(batchImg)
            else:
                output = model(batchImg)

            # tensorboard visualization of training. Equivalent to the former
            # loop `for i in range(len(tasks)): if ... % 1000 == i`, which can
            # match at most one i.
            if iterations > 200 and iterations % 1000 < len(tasks):
                tb_images([
                    batchImg[0, 0, ...], batchLabel[0, ...],
                    torch.argmax(output[0, ...], dim=0)
                ], [False, True, True], ['image', 'GT', 'PS'],
                    iterations,
                    tag='Train_idx{}_{}_batch{}_{}'.format(
                        config.task_idx, config.task, 0,
                        '_'.join(batchAugs[0])))
                last = config.batch_size - 1
                tb_images([
                    batchImg[last, 0, ...], batchLabel[last, ...],
                    torch.argmax(output[last, ...], dim=0)
                ], [False, True, True], ['image', 'GT', 'PS'],
                    iterations,
                    tag='Train_idx{}_{}_batch{}_{}_step{}'.format(
                        config.task_idx, config.task, last,
                        '_'.join(batchAugs[last]), iterations - 1))
                if config.trainMode == "universal":
                    logger.info(
                        'share_map shape:{}, para_map shape:{}'.format(
                            str(share_map.shape), str(para_map.shape)))
                    tb_images([
                        para_map[0, :, 64, ...], share_map[0, :, 64, ...]
                    ], [False, False], ['last_para_map', 'last_share_map'],
                        iterations,
                        tag='Train_idx{}_{}_para_share_maps_channels'.format(
                            config.task_idx, config.task))
                logger.info(
                    '----- {}, train epoch {} time elapsed:{} -----'.format(
                        config.task, epoch,
                        tinies.timer(epoch_st_time, time.time())))

            output_softmax = F.softmax(output, dim=1)
            loss = lovasz_softmax(output_softmax, batchLabel,
                                  ignore=10) + focal_loss(output, batchLabel)

            loss.backward()
            optimizer.step()

            config.writer.add_scalar('data/loss_step', loss.item(), iterations)
            config.writer.add_scalar(
                'data/loss_step_idx{}_{}'.format(config.task_idx, config.task),
                loss.item(), iterations)
            loss_epoch += loss.item()
            num_batch_processed += 1

            loss_epoch_list[config.task_idx] += loss.item()
            num_batch_processed_list[config.task_idx] += 1

            iterations += 1

        if epoch % config.save_epoch == 0:
            _save_train_checkpoint(args, epoch, model, optimizer, loss,
                                   trLoss_queue, last_trLoss_ma)

        loss_epoch /= num_batch_processed
        config.writer.add_scalar('data/loss_epoch', loss_epoch, iterations - 1)
        for idx, task in enumerate(tasks):
            if num_batch_processed_list[idx] == 0:
                # step_per_epoch < len(tasks): this task got no batch this
                # epoch, so skip it instead of dividing by zero.
                continue
            loss_epoch_list[idx] /= num_batch_processed_list[idx]
            config.writer.add_scalar(
                'data/loss_epoch_idx{}_{}'.format(idx, task),
                loss_epoch_list[idx], iterations - 1)

        ### lr decay
        trLoss_queue.append(loss_epoch)
        trLoss_ma = np.asarray(trLoss_queue).mean()  # simple moving average over the window
        config.writer.add_scalar('data/trLoss_ma', trLoss_ma, iterations - 1)

        for idx, task in enumerate(tasks):
            if num_batch_processed_list[idx] == 0:
                continue
            trLoss_queue_list[idx].append(loss_epoch_list[idx])
            trLoss_ma_list[idx] = np.asarray(trLoss_queue_list[idx]).mean()
            config.writer.add_scalar(
                'data/trLoss_ma_idx{}_{}'.format(idx, task),
                trLoss_ma_list[idx], iterations - 1)

        #### online eval
        eval_now = (epoch >= config.start_val_epoch
                    and epoch % config.val_epoch == 0)
        if not eval_now and lr < 1e-8:
            eval_now = True
            logger.info(
                'lr is reduced to {}. Will do the last evaluation for all samples!'
                .format(lr))
        if eval_now:
            eval(args, tasks_archive, model, epoch, iterations - 1)

        ## stop if lr is too low
        if lr < 1e-8:
            logger.info('lr is reduced to {}. Job Done!'.format(lr))
            break

        ###### lr decay once the loss window is full and the average plateaus
        if len(trLoss_queue) == trLoss_queue.maxlen:
            if last_trLoss_ma and last_trLoss_ma - trLoss_ma < 1e-4:  # 5e-3
                lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            last_trLoss_ma = trLoss_ma

        ## save model when lr < 1e-8
        if lr < 1e-8:
            _save_train_checkpoint(args, epoch, model, optimizer, loss,
                                   trLoss_queue, last_trLoss_ma)


def _save_train_checkpoint(args, epoch, model, optimizer, loss, trLoss_queue,
                           last_trLoss_ma):
    """Write the full training state to a timestamped .pth.tar in config.log_dir."""
    ckp_path = os.path.join(
        config.log_dir, '{}_{}_epoch{}_{}.pth.tar'.format(
            args.trainMode, '_'.join(args.tasks), epoch, tinies.datestr()))
    torch.save(
        {
            'epoch': epoch,
            'model': model,
            'model_state_dict': model.state_dict(),
            'optimizer': optimizer,
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'trLoss_queue': trLoss_queue,
            'last_trLoss_ma': last_trLoss_ma
        }, ckp_path)
def trQueue(config_task, ids, dataQueue, patch_size, nProc=1, seed=1):  # data, center_bboxes
    '''
    Producer loop: load training cases and put augmented patches on ``dataQueue``.

    The files from ``load_files(ids)`` are shuffled and repeated enough times
    to cover ``config.step_per_epoch * config.max_epoch`` draws at
    ``config_task.num_patch_per_file`` patches per file. For each file:

    - the '<ID>_volumes.npy', '<ID>_label.npy' and '<ID>_weight.npy' arrays are
      loaded from ``<config.prepData_dir>/<task>/Tr``;
    - they are padded up to ``config_task.patch_size``;
    - ``num_patch_per_file`` dicts with 'any', 'fore' and 'small' patches
      (built by augment_patch with ``config.patch_size``) are put on the queue.

    Args:
        config_task: task config (patch_size, task, queue_size,
            num_patch_per_file).
        ids: case identifiers forwarded to load_files().
        dataQueue: queue-like object providing qsize() and put().
        patch_size: not used; config_task.patch_size is used for padding and
            config.patch_size for the extracted patches.
        nProc: not used.
        seed: seed for the global numpy RNG used for shuffling.

    Returns:
        None. Results are delivered through ``dataQueue``. A failure while
        processing one file is logged and that file is skipped.
    '''
    files = load_files(ids)
    # final_files = files[1:20] # debug
    patches_per_pass = len(files) * config_task.num_patch_per_file
    if patches_per_pass <= 0:
        # previously a ZeroDivisionError in the producer
        logger.info('trQueue: no training files/patches for {}; nothing to enqueue'.format(
            config_task.task))
        return
    max_repeats = math.ceil(config.step_per_epoch * config.max_epoch / patches_per_pass)
    final_files = []
    np.random.seed(seed)
    for _ in range(max_repeats):
        np.random.shuffle(files)
        final_files.extend(files)

    datDir = os.path.join(config.prepData_dir, config_task.task, "Tr")
    for obj in final_files:
        ID = obj['id']
        try:
            # back-pressure: wait while the queue is full
            t_wait = 0
            while dataQueue.qsize() >= config_task.queue_size:
                time.sleep(1)
                t_wait += 1
            if t_wait > 0:
                logger.info(
                    '{} queue is full, size={}, time waited for full:{}'.format(
                        config_task.task, config_task.queue_size, t_wait))

            # mod, d, h, w. The largest liver case "liver_22_volumes.npy" costs 0.6s.
            volumes = np.load(os.path.join(datDir, ID + '_volumes.npy'))
            label = np.load(os.path.join(datDir, ID + '_label.npy'))  # also works for NoneType obj.
            weight = np.load(os.path.join(datDir, ID + '_weight.npy'))

            v_shape = volumes.shape
            l_shape = label.shape
            # for tasks like Task04_Hippocampus, some images are smaller than
            # patch_size: pad up to patch_size; during eval, the CNN output is
            # cropped back to the original size.
            volumes = np.asarray([
                tinies.pad2gePatch(volumes[moda], config_task.patch_size, data_channel=None)[0]
                for moda in range(volumes.shape[0])
            ])
            # could be changed to pad symmetrically instead of asymmetrically. TBD.
            label, _ = tinies.pad2gePatch(label, config_task.patch_size, data_channel=None)
            weight, _ = tinies.pad2gePatch(weight, config_task.patch_size, data_channel=None)
            # explicit check instead of assert, which would be stripped under -O
            if not all(i == j for i, j in zip(volumes.shape[1:], label.shape[0:])):
                raise ValueError(
                    "ID:{}, before pad, volumes shape:{}, label shape:{}; after pad: volumes shape:{}, label shape:{}"
                    .format(ID, str(v_shape), str(l_shape), str(volumes.shape), str(label.shape)))

            data = dict()
            data['ID'] = ID
            data['image'] = volumes
            data['label'] = label
            data['weight'] = weight

            for _ in range(config_task.num_patch_per_file):
                patches = dict()
                patches['any'] = augment_patch(config_task, data, config.patch_size, ptype='any')
                patches['fore'] = augment_patch(config_task, data, config.patch_size, ptype='fore')
                patches['small'] = augment_patch(config_task, data, config.patch_size, ptype='small')
                dataQueue.put(patches)
        except Exception as e:
            # best effort: skip this file and keep producing
            logger.info('error in for-loop of trQueue:{}'.format(str(e)))
def _bs_window_centers(size, patch_len):
    """Return de-duplicated sliding-window centres along one axis of length ``size``.

    Centres are spaced half a patch apart. Any centre beyond
    ``size - patch_len / 2`` is clamped to it, so the last window ends at the
    border. Clamping can map two grid points onto the same centre, which would
    make the same patch be predicted more than once. Duplicates are therefore
    dropped, keeping first-seen order.
    """
    half = int(patch_len / 2)
    clamped = [min(c, size - half) for c in range(half, size + half, half)]
    return list(dict.fromkeys(clamped))


def _bs_predict_and_stitch(model, patches, centers, batch_size, patch_weights,
                           temp_prob1, num_class):
    """Predict one mini-batch of patches and stitch the weighted probabilities into ``temp_prob1``.

    ``patches`` holds between 1 and ``batch_size`` arrays of shape
    (mod, d, h, w), and ``centers`` the matching [D, H, W] centres. A short
    batch is padded with zero patches, whose predictions are ignored.
    ``temp_prob1`` (D, H, W, num_class) is updated in place through
    ``set_roi_to_volume``.
    """
    n_valid = len(patches)
    batch = list(patches)
    batch.extend(np.zeros(patches[0].shape, np.float32) for _ in range(batch_size - n_valid))
    inputs = torch.from_numpy(np.asarray(batch, np.float32)).float().cuda()

    with torch.no_grad():  # inference only: don't build an autograd graph
        if config.trainMode == "universal":
            prob, _share_map, _para_map = model(inputs)
        else:
            prob = model(inputs)

    prob = prob.detach().permute([0, 2, 3, 4, 1])  # n, d, h, w, c
    # Weight every class channel by the per-voxel patch weights (broadcast over n and c).
    prob = (prob * patch_weights.unsqueeze(-1)).cpu().numpy()
    for idx in range(n_valid):
        for c in range(num_class):
            temp_prob1[..., c] = set_roi_to_volume(temp_prob1[..., c], centers[idx], prob[idx, ..., c])


def batch_segmentation(config_task, temp_imgs, model):
    """Sliding-window segmentation of one multi-modal volume.

    Args:
        config_task: per-task config; uses ``patch_size`` and ``num_class``.
        temp_imgs: array of shape (mod, D, H, W).
        model: network called on float CUDA batches of shape
            (config.batch_size, mod, *config.patch_size). Its output is
            (n, c, d, h, w); in 'universal' train mode it returns a 3-tuple
            whose first element is that prediction. Presumably already in
            eval mode — TODO confirm at the call sites.

    Returns:
        uint8 array of per-voxel argmax class indices. It has the padded
        (D, H, W) shape, cropped back with ``pad_size`` when padding occurred.

    Raises:
        ValueError: if ``config.unifyPatch`` is not 'resize'.
    """
    model_patch_size = list(config.patch_size)
    batch_size = config.batch_size
    num_class = config_task.num_class
    patch_weights = torch.from_numpy(config.patch_weights).float().cuda()

    data_channel, original_D, original_H, original_W = temp_imgs.shape
    # Images smaller than the task patch size (e.g. Task04_Hippocampus) are
    # padded up to it, as during training; the prediction is cropped back below.
    temp_imgs, pad_size = tinies.pad2gePatch(temp_imgs, config_task.patch_size, data_channel)
    data_channel, D, H, W = temp_imgs.shape
    oldShape = [D, H, W]

    if config.unifyPatch != 'resize':
        raise ValueError('{}: not yet implemented!!'.format(config.unifyPatch))

    # Shared/universal models use one patch size for all tasks, so the image is
    # rescaled by model_patch_size / task_patch_size before inference.
    rescale = config.trainMode in ["shared", "universal"]
    if rescale:
        scale_factors = [m / t for m, t in zip(model_patch_size, config_task.patch_size)]
        # round() rather than int(): float error (e.g. 95.999...) must not lose a voxel.
        newShape = tuple(int(round(s * f)) for s, f in zip(oldShape, scale_factors))
        temp_imgs = np.asarray([
            skimage.transform.resize(vol, output_shape=newShape, order=3, mode='constant')  # bi-cubic
            for vol in temp_imgs
        ])
    # NOTE: the former TensorBoard figures built here were never added to the
    # writer (add_figure was commented out); they were removed because they only
    # cost time and could make segmentation fail on a plotting error.

    data_channel, D, H, W = temp_imgs.shape
    temp_prob1 = np.zeros([D, H, W, num_class])

    st_time = time.time()
    patches, centers = [], []
    for center_W in _bs_window_centers(W, model_patch_size[2]):
        for center_H in _bs_window_centers(H, model_patch_size[1]):
            for center_D in _bs_window_centers(D, model_patch_size[0]):
                center = [center_D, center_H, center_W]
                patch = np.asanyarray(
                    [extract_roi_from_volume(temp_imgs[chn], center, model_patch_size, fill="zero")
                     for chn in range(data_channel)],
                    np.float32)  # mod, d, h, w
                patches.append(patch)
                centers.append(center)
                if len(patches) == batch_size:
                    _bs_predict_and_stitch(model, patches, centers, batch_size,
                                           patch_weights, temp_prob1, num_class)
                    patches, centers = [], []
    if patches:  # last, partial mini-batch
        _bs_predict_and_stitch(model, patches, centers, batch_size,
                               patch_weights, temp_prob1, num_class)
    logger.info('patch eval for-loop time elapsed:{}'.format(
        tinies.timer(st_time, time.time())))

    temp_pred1 = np.argmax(temp_prob1, axis=-1)
    if rescale:
        # Resize a float copy (per the original note, integer labels passed to
        # resize(order=0) give no usable result); order=0 = nearest neighbour.
        temp_pred1 = skimage.transform.resize(temp_pred1.astype(np.float32),
                                              output_shape=tuple(oldShape),
                                              order=0, mode='constant')
    temp_pred1 = np.asarray(temp_pred1, dtype=np.uint8)

    # Undo the padding applied for images smaller than the patch size.
    if np.any(pad_size):
        temp_pred1 = temp_pred1[pad_size[0]:(original_D + pad_size[0]),
                                pad_size[1]:(original_H + pad_size[1]),
                                pad_size[2]:(original_W + pad_size[2])]
    return temp_pred1
def evaluate(config_task, ids, model, outdir='eval_out', epoch_num=0):
    """Segment and score every training case in ``ids`` for one task.

    For each case:
      - predict with ``segment_one_image``;
      - save the label map to ``outdir`` as NIfTI, prefixed 'Epoch<epoch_num>_';
      - compute per-class Dice against ``config.base_dir/<task>/labelsTr/<ID>``;
      - append a row to ``<outdir>/<task>_eval_res.csv``;
      - send image/GT/prediction figures to TensorBoard via ``train.tb_images``.

    A failure after data loading is logged with the case ID, and that case is
    skipped.

    Returns:
        dict mapping each label name in ``config_task.labels`` to its mean Dice
        over the successfully evaluated cases. If no case succeeded, every value
        is NaN.
    """
    files = list(load_files(ids))
    datDir = os.path.join(config.prepData_dir, config_task.task, "Tr")
    csv_path = os.path.join(outdir, '{}_eval_res.csv'.format(config_task.task))
    dices_list = []
    logger.info('Evaluating epoch{} for {}--- {} cases:\n{}'.format(
        epoch_num, config_task.task, len(files), str([obj['id'] for obj in files])))

    for obj in tqdm(files, desc='Eval epoch{}'.format(epoch_num)):
        ID = obj['id']
        img_path = os.path.join(config.base_dir, config_task.task, "imagesTr", ID)
        gt_path = os.path.join(config.base_dir, config_task.task, "labelsTr", ID)
        obj['im'] = img_path
        obj['gt'] = gt_path
        data = get_eval_data(obj, datDir)
        try:
            final_label = segment_one_image(config_task, data, model, ID)
            save_to_nii(final_label, filename=ID + '.nii.gz', refer_file_path=img_path,
                        outdir=outdir, mode="label", prefix='Epoch{}_'.format(epoch_num))
            gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))  # d, h, w
            # Task03_Liver / Task07_Pancreas: score the tumour label (2) as part of the organ (1).
            if config_task.task in ['Task03_Liver', 'Task07_Pancreas']:
                gt[gt == 2] = 1
            dices = multiClassDice(gt, final_label, config_task.num_class)
            dices_list.append(dices)

            tinies.sureDir(outdir)
            # `with` closes the file every case (it previously leaked one handle per case);
            # list() keeps the row concatenation valid if dices is an ndarray.
            with open(csv_path, mode='a+', newline='') as fo:
                csv.writer(fo, delimiter=',').writerow(
                    [epoch_num, tinies.datestr(), ID] + list(dices))

            # TensorBoard visualisation: first modality only for multi-modal images.
            tb_img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))  # d, h, w
            if tb_img.ndim == 4:
                tb_img = tb_img[0, ...]
            train.tb_images([tb_img, gt, final_label], [False, True, True],
                            ['image', 'GT', 'PS'],
                            epoch_num * config.step_per_epoch,
                            tag='Eval_{}_epoch_{}_dices_{}'.format(ID, epoch_num, str(dices)))
        except Exception as e:
            logger.info('{}: {}'.format(ID, str(e)))

    labels = config_task.labels
    if not dices_list:
        # mean() over an empty array gives a scalar NaN, and indexing it would crash.
        logger.info('Eval: no case was evaluated successfully for {}'.format(config_task.task))
        return {labels[str(i)]: float('nan') for i in range(config_task.num_class)}

    dices_mean = np.asarray(dices_list).mean(axis=0)
    logger.info('Eval mean dices:')
    dices_res = {}
    for i in range(config_task.num_class):
        tag = labels[str(i)]
        dices_res[tag] = dices_mean[i]
        logger.info(' {}, {}'.format(tag, dices_mean[i]))
    return dices_res