def eval(args, tasks_archive, model, eval_epoch, iterations):
    tasks = args.tasks  # list
    model.eval()
    for task_idx in range(len(tasks)):
        config.task_idx = task_idx  # needed for u2net3d().
        task = tasks[task_idx]
        config_task = config.config_tasks[task]

        st_time = time.time()
        # Run evaluation (tensorboard visualization is embedded in evaluate()).
        dices = evaluate.evaluate(config_task,
                                  tasks_archive[task]['fold' + str(args.fold)]['val'],
                                  model,
                                  epoch_num=eval_epoch,
                                  outdir=config.eval_out_dir)
        with open(os.path.join(config.eval_out_dir, '{}_eval_res.csv'.format(args.trainMode)), mode='a+') as fo:
            wo = csv.writer(fo, delimiter=',')
            for k, v in dices.items():
                config.writer.add_scalar('data/dices/{}_{}'.format(task, k), v, iterations)
                wo.writerow([args.trainMode, task, eval_epoch, config.step_per_epoch, k, v, tinies.datestr()])
                fo.flush()
        logger.info('Eval time elapsed:{}'.format(tinies.timer(st_time, time.time())))
def prep(files, outDir, with_gt=True):
    # NOTE: `task` and `config_task` are module-level globals set before prep() is called.
    print("ids[0]:{}, current time:{}".format(
        os.path.basename(files[0]), str(tinies.datestr())))
    for img_path in files:
        ID = os.path.basename(img_path).split('.')[0]
        if with_gt:
            lab_path = os.path.join(config.base_dir, task, 'labelsTr', ID)
        else:
            lab_path = None
        volume_list, label, weight, original_shape, [bbmin, bbmax] = utils.preprocess(
            img_path, lab_path, config_task, with_gt=with_gt)
        volumes = np.asarray(volume_list)
        np.save(os.path.join(outDir, ID + '_volumes.npy'), volumes)
        if with_gt:
            np.save(os.path.join(outDir, ID + '_label.npy'), label)
            np.save(os.path.join(outDir, ID + '_weight.npy'), weight)

        # Stringified metadata; use eval() (or ast.literal_eval()) to restore.
        json_info = dict()
        json_info['original_shape'] = str(original_shape)
        json_info['bbox'] = str([bbmin, bbmax])
        with open(os.path.join(outDir, ID + '.json'), 'w') as f:
            json.dump(json_info, f, indent=4)
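# A minimal sketch of reading the metadata back. The two fields are
# stringified above; ast.literal_eval() is shown here as a safer stand-in for
# the eval() the comments suggest. `load_prep_info` and `json_path` are
# illustrative names, not part of the pipeline.
import ast

def load_prep_info(json_path):
    """Restore the stringified metadata written by prep()."""
    with open(json_path) as f:
        info = json.load(f)
    original_shape = ast.literal_eval(info['original_shape'])
    bbmin, bbmax = ast.literal_eval(info['bbox'])
    return original_shape, (bbmin, bbmax)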
def fuse(files, outDir, with_gt=True):
    print("ids[0]:{}, current time:{}".format(
        os.path.basename(files[0]), str(tinies.datestr())))
    for lab_path in files:
        print('loading:{}'.format(lab_path))
        label = np.load(lab_path)
        label[label == 2] = 1  # fuse the cancer label into the organ label
        np.save(lab_path, label)  # overwrite the label file in place
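# Quick doctest-style check of the fusion rule above (assumed label semantics:
# 1 = organ, 2 = tumour, as in the MSD liver/pancreas tasks):
#   >>> lab = np.array([0, 1, 2, 2, 1])
#   >>> lab[lab == 2] = 1
#   >>> lab.tolist()
#   [0, 1, 1, 1, 1]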
def gen_batch(self, batch_size, patch_size):
    batchImg = np.zeros([batch_size, self.config_task.num_modality,
                         patch_size[0], patch_size[1], patch_size[2]])  # n, mod, d, h, w
    batchLabel = np.zeros([batch_size, patch_size[0], patch_size[1], patch_size[2]])  # n, d, h, w
    batchWeight = np.zeros([batch_size, patch_size[0], patch_size[1], patch_size[2]])  # n, d, h, w
    batchAugs = list()

    for i in range(batch_size):
        temp_prob = np.random.uniform()
        st_time = time.time()
        handler = 0
        while handler == 0:
            t_wait = 0
            if self.trainQueue.qsize() == 0:
                logger.info('{} self.trainQueue size = {}, filling....(start time:{})'.format(
                    self.task, self.trainQueue.qsize(), tinies.datestr()))
                while self.trainQueue.qsize() == 0:
                    time.sleep(1)
                    t_wait += 1
            if t_wait > 0:
                logger.info('{} time to fill self.trainQueue: {}'.format(self.task, t_wait))
            patches = self.trainQueue.get()

            if i <= math.ceil(batch_size / 3):
                # nnU-Net heuristic: at least 1/3 of the samples in a batch
                # should contain at least one foreground class.
                if temp_prob < self.config_task.small_prob and patches['small'] is not None:
                    patch = patches['small']
                    handler = 1
                elif patches['fore'] is not None:
                    patch = patches['fore']
                    handler = 1
                else:
                    handler = 0
                    logger.warning('handler={}, drawing another set of patches'.format(handler))
            else:  # i > math.ceil(batch_size / 3)
                if temp_prob < self.config_task.small_prob and patches['small'] is not None:
                    patch = patches['small']
                    handler = 1
                elif 1 - temp_prob < self.config_task.fore_prob and patches['fore'] is not None:
                    patch = patches['fore']
                    handler = 1
                else:
                    patch = patches['any']
                    handler = 1

        # fill in a batch
        batchImg[i, ...] = patch['image']
        batchLabel[i, ...] = patch['label']
        batchWeight[i, ...] = patch['weight']
        batchAugs.append(patch['augs'])
    return (batchImg, batchLabel, batchWeight, batchAugs)
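# Illustrative sketch of the sampling heuristic above, decoupled from the
# queue machinery. `choose_patch_kind` is a hypothetical helper (not part of
# the loader); the availability checks and the retry loop are omitted.
def choose_patch_kind(i, batch_size, temp_prob, small_prob, fore_prob):
    """Return which patch pool ('small', 'fore' or 'any') slot i draws from."""
    if i <= math.ceil(batch_size / 3):
        # the first third of the batch is forced to contain foreground (nnU-Net heuristic)
        return 'small' if temp_prob < small_prob else 'fore'
    if temp_prob < small_prob:
        return 'small'  # oversample patches around small structures
    if 1 - temp_prob < fore_prob:
        return 'fore'   # oversample foreground patches
    return 'any'        # otherwise sample anywhere in the volume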
def train(args, tasks_archive, model):
    torch.backends.cudnn.benchmark = True

    if args.resume_ckp != '':
        logger.info('==> loading checkpoint: {}'.format(args.resume_ckp))
        checkpoint = torch.load(args.resume_ckp)
        # NOTE: only the optimizer and loss history are restored below; the
        # model weights are assumed to be restored by the caller.

    model = nn.parallel.DataParallel(model)
    logger.info(' + model num_params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if config.use_gpu:
        model.cuda()  # move the model to GPU before constructing the optimizer
    print(model)  # especially useful for debugging the model structure
    # summary(model, input_size=tuple([config.num_modality] + config.patch_size))
    # prints each layer's output shape; slow, so keep commented unless debugging.

    # lr
    lr = config.base_lr
    if args.resume_ckp != '':
        optimizer = checkpoint['optimizer']
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=config.weight_decay)

    # losses (dice_loss and ce_loss are instantiated but unused below;
    # training uses lovasz_softmax + focal_loss)
    dice_loss = MulticlassDiceLoss()
    ce_loss = nn.CrossEntropyLoss()
    focal_loss = FocalLoss(gamma=2)

    # prep data
    tasks = args.tasks  # list
    tb_loaders = list()  # train batch loaders
    len_loader = list()
    for task in tasks:
        tb_loader = tb_load(task)
        tb_loader.enQueue(tasks_archive[task]['fold' + str(args.fold)], config.patch_size)
        tb_loaders.append(tb_loader)
        len_loader.append(len(tb_loader))
    min_len_loader = np.min(len_loader)

    # init train values
    if args.resume_ckp != '':
        trLoss_queue = checkpoint['trLoss_queue']
        last_trLoss_ma = checkpoint['last_trLoss_ma']
    else:
        trLoss_queue = deque(maxlen=config.trLoss_win)  # epoch losses over the last N epochs, for a moving average
        last_trLoss_ma = None  # the previous moving average
    trLoss_queue_list = [deque(maxlen=config.trLoss_win) for i in range(len(tasks))]
    last_trLoss_ma_list = [None] * len(tasks)
    trLoss_ma_list = [None] * len(tasks)

    if args.resume_epoch > 0:
        start_epoch = args.resume_epoch + 1
        iterations = args.resume_epoch * config.step_per_epoch + 1
    else:
        start_epoch = 1
        iterations = 1
    logger.info('start epoch: {}'.format(start_epoch))

    ## run train
    for epoch in range(start_epoch, config.max_epoch + 1):
        logger.info(' ----- training epoch {} -----'.format(epoch))
        epoch_st_time = time.time()
        model.train()
        loss_epoch = 0.0
        loss_epoch_list = [0] * len(tasks)
        num_batch_processed = 0  # growing
        num_batch_processed_list = [0] * len(tasks)
        for step in tqdm(range(config.step_per_epoch),
                         desc='{}: epoch{}'.format(args.trainMode, epoch)):
            config.step = iterations
            config.task_idx = (iterations - 1) % len(tasks)  # round-robin over tasks
            config.task = tasks[config.task_idx]

            # tb show lr
            config.writer.add_scalar('data/lr', lr, iterations - 1)

            for idx in range(len(tasks)):
                tb_loaders[idx].check_process()
            (batchImg, batchLabel, batchWeight, batchAugs) = \
                tb_loaders[config.task_idx].gen_batch(config.batch_size, config.patch_size)

            batchImg = torch.from_numpy(batchImg).float()  # cast all inputs to the same tensor type
            batchLabel = torch.from_numpy(batchLabel).float()
            batchWeight = torch.from_numpy(batchWeight).float()
            if config.use_gpu:
                batchImg = batchImg.cuda()
                batchLabel = batchLabel.cuda()
                batchWeight = batchWeight.cuda()

            optimizer.zero_grad()
            if config.trainMode in ["universal"]:
                output, share_map, para_map = model(batchImg)
            else:
                output = model(batchImg)

            # tensorboard visualization of training
            for i in range(len(tasks)):
                if iterations > 200 and iterations % 1000 == i:
                    tb_images([batchImg[0, 0, ...], batchLabel[0, ...],
                               torch.argmax(output[0, ...], dim=0)],
                              [False, True, True], ['image', 'GT', 'PS'], iterations,
                              tag='Train_idx{}_{}_batch{}_{}'.format(
                                  config.task_idx, config.task, 0, '_'.join(batchAugs[0])))
                    tb_images([batchImg[config.batch_size - 1, 0, ...],
                               batchLabel[config.batch_size - 1, ...],
                               torch.argmax(output[config.batch_size - 1, ...], dim=0)],
                              [False, True, True], ['image', 'GT', 'PS'], iterations,
                              tag='Train_idx{}_{}_batch{}_{}_step{}'.format(
                                  config.task_idx, config.task, config.batch_size - 1,
                                  '_'.join(batchAugs[config.batch_size - 1]), iterations - 1))
                    if config.trainMode == "universal":
                        logger.info('share_map shape:{}, para_map shape:{}'.format(
                            str(share_map.shape), str(para_map.shape)))
                        tb_images([para_map[0, :, 64, ...], share_map[0, :, 64, ...]],
                                  [False, False], ['last_para_map', 'last_share_map'],
                                  iterations,
                                  tag='Train_idx{}_{}_para_share_maps_channels'.format(
                                      config.task_idx, config.task))
                    logger.info('----- {}, train epoch {} time elapsed:{} -----'.format(
                        config.task, epoch, tinies.timer(epoch_st_time, time.time())))

            output_softmax = F.softmax(output, dim=1)
            loss = lovasz_softmax(output_softmax, batchLabel, ignore=10) \
                + focal_loss(output, batchLabel)
            loss.backward()
            optimizer.step()

            config.writer.add_scalar('data/loss_step', loss.item(), iterations)
            config.writer.add_scalar('data/loss_step_idx{}_{}'.format(
                config.task_idx, config.task), loss.item(), iterations)
            loss_epoch += loss.item()
            num_batch_processed += 1
            loss_epoch_list[config.task_idx] += loss.item()
            num_batch_processed_list[config.task_idx] += 1

            iterations += 1

        if epoch % config.save_epoch == 0:
            ckp_path = os.path.join(config.log_dir, '{}_{}_epoch{}_{}.pth.tar'.format(
                args.trainMode, '_'.join(args.tasks), epoch, tinies.datestr()))
            torch.save({
                'epoch': epoch,
                'model': model,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer,
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'trLoss_queue': trLoss_queue,
                'last_trLoss_ma': last_trLoss_ma
            }, ckp_path)

        loss_epoch /= num_batch_processed
        config.writer.add_scalar('data/loss_epoch', loss_epoch, iterations - 1)
        for idx in range(len(tasks)):
            task = tasks[idx]
            loss_epoch_list[idx] /= num_batch_processed_list[idx]
            config.writer.add_scalar('data/loss_epoch_idx{}_{}'.format(idx, task),
                                     loss_epoch_list[idx], iterations - 1)

        ### track loss moving averages
        trLoss_queue.append(loss_epoch)
        trLoss_ma = np.asarray(trLoss_queue).mean()  # simple moving average; an exponential moving average could be used instead
        config.writer.add_scalar('data/trLoss_ma', trLoss_ma, iterations - 1)
        for idx in range(len(tasks)):
            task = tasks[idx]
            trLoss_queue_list[idx].append(loss_epoch_list[idx])
            trLoss_ma_list[idx] = np.asarray(trLoss_queue_list[idx]).mean()
            config.writer.add_scalar('data/trLoss_ma_idx{}_{}'.format(idx, task),
                                     trLoss_ma_list[idx], iterations - 1)

        #### online eval
        Eval_bool = False
        if epoch >= config.start_val_epoch and epoch % config.val_epoch == 0:
            Eval_bool = True
        elif lr < 1e-8:
            Eval_bool = True
            logger.info('lr is reduced to {}. Will do the last evaluation for all samples!'.format(lr))
        if Eval_bool:
            eval(args, tasks_archive, model, epoch, iterations - 1)

        ## stop if lr is too low
        if lr < 1e-8:
            logger.info('lr is reduced to {}. Job Done!'.format(lr))
            break

        ###### lr decay: halve lr when the training-loss moving average plateaus
        if len(trLoss_queue) == trLoss_queue.maxlen:
            if last_trLoss_ma is not None and last_trLoss_ma - trLoss_ma < 1e-4:  # alternative threshold: 5e-3
                lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            last_trLoss_ma = trLoss_ma

    ## save the model once lr has dropped below 1e-8
    if lr < 1e-8:
        ckp_path = os.path.join(config.log_dir, '{}_{}_epoch{}_{}.pth.tar'.format(
            args.trainMode, '_'.join(args.tasks), epoch, tinies.datestr()))
        torch.save({
            'epoch': epoch,
            'model': model,
            'model_state_dict': model.state_dict(),
            'optimizer': optimizer,
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'trLoss_queue': trLoss_queue,
            'last_trLoss_ma': last_trLoss_ma
        }, ckp_path)
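# Illustrative sketch of resuming from a checkpoint written by train(). The
# dict keys match the torch.save() calls above; `load_for_resume` and
# `ckp_path` are hypothetical. Restoring via the saved state_dicts is shown as
# an alternative to unpickling the whole optimizer object, which is what
# train() currently does; `model` is assumed to be wrapped in DataParallel the
# same way it was when the checkpoint was saved.
def load_for_resume(ckp_path, model, optimizer):
    checkpoint = torch.load(ckp_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['trLoss_queue'], checkpoint['last_trLoss_ma']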
def evaluate(config_task, ids, model, outdir='eval_out', epoch_num=0):
    """Evaluation."""
    files = load_files(ids)
    files = list(files)
    datDir = os.path.join(config.prepData_dir, config_task.task, "Tr")

    dices_list = []
    logger.info('Evaluating epoch{} for {} --- {} cases:\n{}'.format(
        epoch_num, config_task.task, len(files), str([obj['id'] for obj in files])))
    for obj in tqdm(files, desc='Eval epoch{}'.format(epoch_num)):
        ID = obj['id']
        obj['im'] = os.path.join(config.base_dir, config_task.task, "imagesTr", ID)
        obj['gt'] = os.path.join(config.base_dir, config_task.task, "labelsTr", ID)
        img_path = obj['im']
        gt_path = obj['gt']
        data = get_eval_data(obj, datDir)
        try:
            final_label = segment_one_image(config_task, data, model, ID)  # final_label: d, h, w
            save_to_nii(final_label, filename=ID + '.nii.gz', refer_file_path=img_path,
                        outdir=outdir, mode="label", prefix='Epoch{}_'.format(epoch_num))
            gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))  # d, h, w
            # treat cancer as organ for Task03_Liver and Task07_Pancreas
            if config_task.task in ['Task03_Liver', 'Task07_Pancreas']:
                gt[gt == 2] = 1

            # compute dices
            dices = multiClassDice(gt, final_label, config_task.num_class)
            dices_list.append(dices)
            tinies.sureDir(outdir)
            with open(os.path.join(outdir, '{}_eval_res.csv'.format(config_task.task)), mode='a+') as fo:
                wo = csv.writer(fo, delimiter=',')
                wo.writerow([epoch_num, tinies.datestr(), ID] + dices)
                fo.flush()

            ## tensorboard visualization
            tb_img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))  # d, h, w
            if tb_img.ndim == 4:
                tb_img = tb_img[0, ...]
            train.tb_images([tb_img, gt, final_label], [False, True, True],
                            ['image', 'GT', 'PS'], epoch_num * config.step_per_epoch,
                            tag='Eval_{}_epoch_{}_dices_{}'.format(ID, epoch_num, str(dices)))
        except Exception as e:
            logger.info('evaluation failed for {}: {}'.format(ID, str(e)))

    labels = config_task.labels
    dices_all = np.asarray(dices_list)
    dices_mean = dices_all.mean(axis=0)
    logger.info('Eval mean dices:')
    dices_res = {}
    for i in range(config_task.num_class):
        tag = labels[str(i)]
        dices_res[tag] = dices_mean[i]
        logger.info('  {}, {}'.format(tag, dices_mean[i]))
    return dices_res
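# Minimal sketch of a per-class Dice with the semantics multiClassDice() is
# assumed to have here (integer label volumes in, one Dice score per class
# out); the real helper may differ. `_dice_per_class` is a hypothetical name.
def _dice_per_class(gt, pred, num_class):
    dices = []
    for c in range(num_class):
        g = (gt == c)
        p = (pred == c)
        denom = g.sum() + p.sum()
        # define Dice as 1.0 when the class is absent from both volumes
        dices.append(2.0 * np.logical_and(g, p).sum() / denom if denom > 0 else 1.0)
    return dices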