def test(self):
    # run single-image inference on the held-out split and log Dice / IoU
    self.batch_size = 1
    test_images = tf.placeholder(dtype=tf.float32,
                                 shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    test_gt = tf.placeholder(dtype=tf.float32,
                             shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    u_test, _, _, _, _, _, _, _, _ = self.PD_v2(test_images, reuse=tf.AUTO_REUSE, training=False)

    saver = tf.train.Saver()

    with tf.name_scope('All_Metrics'):
        dice_coe_test = dice_coe(output=u_test, target=test_gt)
        iou_coe_test = iou_coe(output=u_test, target=test_gt)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if self.restore:
            saver.restore(sess, self.checkpoint)

        t_f, v_f = self.create_list(data_path=self.data_dir)
        data_test = self.load_data_valid(data_path=self.data_dir, list=v_f)

        dice_test_list = []
        iou_test_list = []
        for test_iter in tqdm(range(self.num_test_samples)):
            test_i, test_m = next(data_test)
            dice_coe_test_value, iou_coe_test_value = sess.run(
                [dice_coe_test, iou_coe_test],
                feed_dict={test_images: test_i, test_gt: test_m})
            dice_test_list.append(dice_coe_test_value)
            iou_test_list.append(iou_coe_test_value)

        # persist per-image and averaged metrics, then read them back for a quick report
        if not os.path.exists('Test'):
            os.makedirs('Test')
        np.savez('Test/Test_' + self.model_name + '.npz',
                 dice=dice_test_list,
                 iou=iou_test_list,
                 dice_avg=np.average(dice_test_list),
                 iou_avg=np.average(iou_test_list))

        data = np.load('Test/Test_' + self.model_name + '.npz')
        print(data['dice'])
        print(data['iou'])
        print(data['dice_avg'])
        print(data['iou_avg'])
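# ---------------------------------------------------------------------------
# Note: dice_coe / iou_coe used above are assumed to be the usual soft Dice and
# soft IoU metrics (TensorLayer-style helpers taking `output` and `target`
# tensors). A minimal TF1 sketch of those formulas, for reference only -- the
# project's real helpers may add thresholding or smoothing options. Assumes
# `tf` (TensorFlow 1.x) is imported at module level, as in the rest of this file.
def soft_dice_sketch(output, target, smooth=1e-5):
    inter = tf.reduce_sum(output * target)
    return (2. * inter + smooth) / (tf.reduce_sum(output) + tf.reduce_sum(target) + smooth)

def soft_iou_sketch(output, target, smooth=1e-5):
    inter = tf.reduce_sum(output * target)
    union = tf.reduce_sum(output) + tf.reduce_sum(target) - inter
    return (inter + smooth) / (union + smooth)
# ---------------------------------------------------------------------------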
def validation(trained_net, val_set, criterion, device, batch_size):
    '''
    Used for evaluation during the training phase.

    params trained_net: trained U-net
    params val_set: validation dataset
    params criterion: loss function
    params device: cpu or gpu
    params batch_size: validation batch size
    '''
    n_val = len(val_set)
    val_loader = val_set.load()

    tot = 0
    acc = 0

    with tqdm(total=n_val, desc='Validation round', unit='patch', leave=False) as pbar:
        with torch.no_grad():
            for i, sample in enumerate(val_loader):
                images, segs = sample['image'].to(device=device), sample['seg'].to(device=device)

                preds = trained_net(images)
                val_loss = criterion(preds, segs)
                dice_score = dice_coe(preds.detach().cpu(), segs.detach().cpu())

                tot += val_loss.detach().item()
                acc += dice_score['avg']

                pbar.set_postfix(**{
                    'validation loss (images)': val_loss.detach().item(),
                    'val_acc': dice_score['avg']
                })
                pbar.update(images.shape[0])

    # average loss and accuracy over the number of validation batches
    return tot / np.ceil(n_val / batch_size), acc / np.ceil(n_val / batch_size)
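# ---------------------------------------------------------------------------
# validation() (and train() below) index the result of dice_coe() with the keys
# 'avg', 'BG', 'NET', 'ED' and 'ET'. A minimal sketch of a metric with that
# interface, assuming softmax outputs of shape N x C x D x H x W and integer
# label volumes of shape N x D x H x W; the project's own dice_coe may differ.
# `torch` is assumed to be imported at module level.
def dice_coe_sketch(preds, targets, labels=('BG', 'NET', 'ED', 'ET'), eps=1e-5):
    pred_labels = preds.argmax(dim=1)  # hard per-voxel class predictions
    scores = {}
    for idx, name in enumerate(labels):
        p = (pred_labels == idx).float()
        t = (targets == idx).float()
        inter = (p * t).sum()
        scores[name] = ((2. * inter + eps) / (p.sum() + t.sum() + eps)).item()
    scores['avg'] = sum(scores[name] for name in labels) / len(labels)
    return scores
# ---------------------------------------------------------------------------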
def predict(trained_net,
            model_name,
            test_patient,
            root,
            crop_size=98,
            overlap_size=0,
            save_mask=True):
    '''
    Used for predicting image segmentation after training.
    '''
    # load test image: stack all modalities, keep the segmentation separately
    patient_name = os.path.basename(test_patient)
    modality_dir = os.listdir(test_patient)

    image = []
    for modality in modality_dir:
        if modality != patient_name + '_seg.nii.gz':
            path = os.path.join(test_patient, modality)
            img = sitk.GetArrayFromImage(sitk.ReadImage(path))
            image.append(img)
        else:
            path = os.path.join(test_patient, modality)
            seg = sitk.ReadImage(path)
            seg_arr = sitk.GetArrayFromImage(seg)
    image = np.stack(image)  # C*D*H*W

    # model inference: patch-wise sliding window over the volume
    trained_net.eval()
    image_shape = image.shape[-3:]
    crop_info = crop_index_gen(image_shape=image_shape,
                               crop_size=crop_size,
                               overlap_size=overlap_size)
    image_patches = image_crop(image, crop_info, norm=True, ToTensor=True)

    cropped_image_list = np.zeros_like(image_patches.cpu().numpy())
    with torch.no_grad():
        with tqdm(total=len(cropped_image_list),
                  desc='inference test image',
                  unit='patch') as pbar:
            for i, patch in enumerate(image_patches):
                patch = patch.unsqueeze(dim=0)
                preds = trained_net(patch)
                cropped_image_list[i, ...] = preds.squeeze(0).detach().cpu().numpy()
                pbar.update(1)

    crop_index = crop_info['index_array']
    rebuild_four_channels = image_rebuild(crop_index, cropped_image_list)
    inferenced_mask = inference_output(rebuild_four_channels)

    # calculate DSC
    target = torch.from_numpy(seg_arr).unsqueeze(0)
    pred = torch.from_numpy(inferenced_mask).unsqueeze(0)
    pred = to_one_hot(pred)
    dsc = dice_coe(pred, target)
    print('DSC by label of this image is: ', dsc)

    # plot predicted segmentation (middle axial slice)
    plt.figure(figsize=(20, 10))
    ground_truth = seg_arr[image_shape[0] // 2]
    predicted = inferenced_mask[image_shape[0] // 2]
    image_list = [ground_truth, predicted]
    subtitles = ['ground truth', 'predicted']
    plt.subplots_adjust(wspace=0.3)
    for i in range(1, 3):
        ax = plt.subplot(1, 2, i)
        ax.set_title(subtitles[i - 1])
        sns.heatmap(image_list[i - 1],
                    vmin=0,
                    vmax=4,
                    xticklabels=False,
                    yticklabels=False,
                    square=True,
                    cmap='coolwarm',
                    cbar=True)

    # save prediction
    save_path = os.path.join(root, 'prediction_results', patient_name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    mask = sitk.GetImageFromArray(inferenced_mask.astype(np.int16))
    if save_mask:
        plt.savefig(
            os.path.join(save_path,
                         '{}_2D_prediction_{}.png'.format(patient_name, model_name)))
        mask.CopyInformation(seg)
        sitk.WriteImage(
            mask,
            os.path.join(save_path,
                         '{}_2D_prediction_{}.nii.gz'.format(patient_name, model_name)))
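# ---------------------------------------------------------------------------
# Example invocation of predict(). Paths, model name and checkpoint file are
# illustrative only; the checkpoint keys follow the `state` dict saved in
# train() below:
#
#   net = init_U_Net(input_modalites, output_channels, base_channel)
#   ckp = torch.load('path/to/model_ckp.pth.tar', map_location='cpu')
#   net.load_state_dict(ckp['model_state_dict'])
#   predict(net,
#           model_name='unet_baseline',
#           test_patient='/data/BraTS/LGG/Brats18_XXXX_1',
#           root='/experiments',
#           crop_size=98,
#           overlap_size=0,
#           save_mask=True)
# ---------------------------------------------------------------------------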
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'
    config = load_config(config_file)

    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']
    data_class = config['PATH']['data_class']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')

    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)
    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for training and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)
    # mixed-precision training via NVIDIA apex
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            ckp_path = os.path.join(model_dir, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)
            logger.info(
                'Training loss from last time is {}'.format(start_loss) + '\n' +
                'Minimum training loss from last time is {}'.format(min_loss))
        except:
            logger.warning('No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # set up train mode
        net.train()

        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                # loss.backward()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(), segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(**{
                    'Training loss': loss.detach().item(),
                    'Training (avg) accuracy': dice_score['avg']
                })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    # validate
                    net.eval()
                    val_loss, val_acc = validation(net, val_set, criterion, device, batch_size)

                    train_info['train_loss'].append(running_loss / nb_batches)
                    train_info['val_loss'].append(val_loss)
                    train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
                    train_info['NET_acc'].append(dice_coeff_net / nb_batches)
                    train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
                    train_info['ET_acc'].append(dice_coeff_et / nb_batches)

                    # save best trained model
                    scheduler.step(running_loss / nb_batches)
                    if min_loss > val_loss:
                        min_loss = val_loss
                        is_best = True
                        early_stop_count = 0
                    else:
                        is_best = False
                        early_stop_count += 1

                    state = {
                        'epoch': epoch + 1,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': running_loss / nb_batches,
                        'min_loss': min_loss
                    }
                    verbose = save_ckp(state,
                                       is_best,
                                       early_stop_count=early_stop_count,
                                       early_stop_patience=early_stop_patience,
                                       save_model_dir=model_path,
                                       best_dir=best_path,
                                       name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / (np.ceil(n_train / batch_size))))
        logger.info('Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
            val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('finish training!')
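# ---------------------------------------------------------------------------
# train() delegates checkpointing to save_ckp / load_ckp. A minimal sketch of
# the behaviour assumed by the calls above (the repo's own helpers may differ,
# e.g. in the best-model file name); `torch` and `os` are imported at module level.
def save_ckp_sketch(state, is_best, early_stop_count, early_stop_patience,
                    save_model_dir, best_dir, name):
    # always keep the latest checkpoint
    torch.save(state, os.path.join(save_model_dir, '{}_model_ckp.pth.tar'.format(name)))
    # keep a separate copy of the best model so far (file name is an assumption)
    if is_best:
        torch.save(state, os.path.join(best_dir, '{}_best_model.pth.tar'.format(name)))
    # tell the caller whether early stopping should trigger
    return early_stop_count >= early_stop_patience

def load_ckp_sketch(ckp_path, net, optimizer, scheduler):
    ckp = torch.load(ckp_path)
    net.load_state_dict(ckp['model_state_dict'])
    optimizer.load_state_dict(ckp['optimizer_state_dict'])
    scheduler.load_state_dict(ckp['scheduler_state_dict'])
    return net, optimizer, scheduler, ckp['epoch'], ckp['min_loss'], ckp['loss']
# ---------------------------------------------------------------------------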
def train(self):
    is_random = tf.placeholder(dtype=tf.bool, name='is_random')

    train_images = tf.placeholder(dtype=tf.float32,
                                  shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    validation_images = tf.placeholder(dtype=tf.float32,
                                       shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    test_images = tf.placeholder(dtype=tf.float32,
                                 shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    train_gt = tf.placeholder(dtype=tf.float32,
                              shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    validation_gt = tf.placeholder(dtype=tf.float32,
                                   shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])
    test_gt = tf.placeholder(dtype=tf.float32,
                             shape=[self.batch_size, self.IM_ROWS, self.IM_COLS, self.IM_DEPTH])

    # data augmentation: random rotation/flips applied jointly to image and mask
    train_pack = tf.concat([train_images, train_gt], axis=-1)
    train_pack = tf.cond(is_random, lambda: tf.image.rot90(train_pack), lambda: train_pack)
    train_pack = tf.image.random_flip_up_down(train_pack)
    train_pack = tf.image.random_flip_left_right(train_pack)
    # train_pack = tf.image.random_crop(train_pack)
    train_images_da = tf.expand_dims(train_pack[:, :, :, 0], -1)
    train_gt_da = tf.expand_dims(train_pack[:, :, :, 1], -1)

    u_train, td, tp, lda, alpha, delta, sigma, f, th = self.PD_v2(train_images_da, training=True)
    u_validation, _, _, _, _, delta_v, _, _, th_v = self.PD_v2(validation_images, reuse=True, training=False)
    u_test, _, _, _, _, _, _, _, _ = self.PD_v2(test_images, reuse=True, training=False)

    with tf.variable_scope('Regularization'):
        TV2D = tf.reduce_mean(self.TV2D(th))

    with tf.name_scope('Bayesian_Weights'):
        s1 = tf.Variable(tf.sqrt(0.5), name='s1')
        s2 = tf.Variable(0.1, name='s2')
        c1 = tf.div(1., 2. * (tf.square(s1)), name='c1')
        c2 = tf.div(1., 2. * (tf.square(s2)), name='c2')

    # train_gt_flat = tf.reshape(train_gt, [self.batch_size, -1])
    # u_train_flat = tf.reshape(u_train, [self.batch_size, -1])
    # train_gt_flat = tf.concat([train_gt_flat, 1.0 - train_gt_flat], axis=-1)
    # u_train_flat = tf.concat([u_train_flat, 1.0 - u_train_flat], axis=-1)
    # train_loss = tf.losses.softmax_cross_entropy(onehot_labels=train_gt_flat, logits=u_train_flat)
    # loss_l2 = tf.losses.mean_squared_error(labels=train_gt, predictions=u_train)
    # u_train = delta
    loss_l2 = tf.losses.mean_squared_error(labels=train_gt_da, predictions=u_train)
    U2 = tf.reduce_mean(tf.square(delta))
    train_loss = c1 * tf.reduce_mean(1. - dice_coe(output=u_train, target=train_gt_da)) + s1 + 0.001 * U2
    # train_loss = c1 * loss_l2 + s1 + 0.001 * U2
    # u_train = tf.maximum(0.0, tf.minimum(1., u_train))
    # validation_loss = tf.losses.mean_squared_error(labels=validation_gt, predictions=u_validation)
    # u_validation = tf.maximum(0.0, tf.minimum(1., u_validation))
    validation_loss = 1. - dice_coe(output=u_validation, target=validation_gt)
    test_loss = 1. - dice_coe(output=u_test, target=test_gt)

    global_step = tf.Variable(0, trainable=False)
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    optimizer = optimizer_opt(self,
                              loss=train_loss,
                              var_list=var_list,
                              global_step=global_step,
                              show_gradients=True)

    saver = tf.train.Saver()

    with tf.name_scope('All_Metrics'):
        dice_coe_t = dice_coe(output=u_train, target=train_gt)
        iou_coe_t = iou_coe(output=u_train, target=train_gt)
        dice_coe_v = dice_coe(output=u_validation, target=validation_gt)
        iou_coe_v = iou_coe(output=u_validation, target=validation_gt)
        dice_coe_test = dice_coe(output=u_test, target=test_gt)
        iou_coe_test = iou_coe(output=u_test, target=test_gt)

    # Summaries
    random_slice = tf.placeholder(dtype=tf.int32)

    with tf.name_scope('Training'):
        with tf.name_scope('Bayesian_weights'):
            tf.summary.scalar('c1', c1)
            # tf.summary.scalar('c2', c2)
        with tf.name_scope('Metrics'):
            tf.summary.scalar('Dice_coe', dice_coe_t)
            tf.summary.scalar('Iou_coe', iou_coe_t)
        with tf.name_scope('Images'):
            tf.summary.image('Input', train_images_da)
            out = u_train
            out = tf.cast(tf.reshape(out, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
            tf.summary.image('Output', out)
            sum_mask = train_gt_da
            sum_mask = tf.cast(tf.reshape(sum_mask, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
            tf.summary.image('GT', sum_mask)
            tf.summary.image('delta', delta)
            tf.summary.image('th', th)
        with tf.name_scope('Losses'):
            tf.summary.scalar('Loss', train_loss)
            tf.summary.scalar('TV2D_Reg', TV2D)
            tf.summary.scalar('U2_Reg', U2)
        with tf.name_scope('Parameters'):
            tf.summary.scalar('tp', tp)
            tf.summary.scalar('td', td)
            tf.summary.scalar('Lambda', lda[0, 0, 0, 0])
            tf.summary.scalar('alpha', alpha[0, 0, 0, 0])
            tf.summary.scalar('sigma', sigma)

    summary_train = tf.summary.merge_all()

    with tf.name_scope('Validation'):
        with tf.name_scope('Losses'):
            vl = tf.summary.scalar('Loss', validation_loss)
        with tf.name_scope('Metrics'):
            dv = tf.summary.scalar('Dice_coe', dice_coe_v)
            iv = tf.summary.scalar('Iou_coe', iou_coe_v)
        with tf.name_scope('Images'):
            inv = tf.summary.image('Input', validation_images)
            out = u_validation
            out = tf.cast(tf.reshape(out, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
            outv = tf.summary.image('Output', out)
            sum_mask = validation_gt
            sum_mask = tf.cast(tf.reshape(sum_mask, (self.batch_size, self.IM_ROWS, self.IM_COLS, 1)), tf.float32)
            gtv = tf.summary.image('GT', sum_mask)
            delv = tf.summary.image('delta', delta_v)
            thv = tf.summary.image('th', th_v)
    summary_validation = tf.summary.merge(inputs=[vl, iv, dv, outv, gtv, inv, delv, thv])

    with tf.name_scope('Test'):
        dice_test = tf.placeholder(dtype=tf.float32)
        iou_test = tf.placeholder(dtype=tf.float32)
        with tf.name_scope('Metrics'):
            d_test = tf.summary.scalar('Dice_coe', dice_test)
            i_test = tf.summary.scalar('Iou_coe', iou_test)
    summary_test = tf.summary.merge(inputs=[d_test, i_test])

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()

    with tf.Session() as sess:
        summary_writer_t = tf.summary.FileWriter(self.log_dir + '/train', sess.graph, filename_suffix='train')
        summary_writer_v = tf.summary.FileWriter(self.log_dir + '/validation', sess.graph, filename_suffix='valid')
        summary_writer_test = tf.summary.FileWriter(self.log_dir + '/test', sess.graph, filename_suffix='test')

        sess.run(tf.global_variables_initializer())

        if self.restore:
            saver.restore(sess, self.checkpoint)
            counter = np.load('counter.npy')
            epocas = np.load('epocas.npy')
        else:
            counter = 0
            epocas = 0

        t_f, v_f = self.create_list(data_path=self.data_dir)
        data_train = self.load_data_train(data_path=self.data_dir, list=t_f)
        data_valid = self.load_data_valid(data_path=self.data_dir, list=v_f)
        # data_test = self.load_data_test(data_path=self.data_dir, list=v_f)

        for i in tqdm(range(self.training_iters - counter)):
            # reshuffle the file lists at the start of every epoch
            if (counter % self.epoch_iteration) == 0 or counter == 0:
                print('Shuffle Data')
                t_f = self.shuffle_data(t_f)
                v_f = self.shuffle_data(v_f)
                data_train = self.load_data_train(data_path=self.data_dir, list=t_f)
                data_valid = self.load_data_valid(data_path=self.data_dir, list=v_f)

            if not ((counter % (self.epoch_iteration / 10)) == 0):
                # plain training step
                ti, tm = next(data_train)
                _ = sess.run([optimizer],
                             feed_dict={train_images: ti,
                                        train_gt: tm,
                                        is_random: np.random.randint(0, 2)})
            else:
                # every tenth of an epoch: also run validation and write summaries
                ti, tm = next(data_train)
                vi, vm = next(data_valid)
                _, summary_str_t, summary_str_v, global_step_value = sess.run(
                    [optimizer, summary_train, summary_validation, global_step],
                    feed_dict={train_images: ti,
                               train_gt: tm,
                               validation_images: vi,
                               validation_gt: vm,
                               is_random: np.random.randint(0, 2)},
                    options=run_options,
                    run_metadata=run_metadata)

                summary_writer_t.add_run_metadata(run_metadata, 'step%d' % counter)
                summary_writer_t.add_summary(summary_str_t, global_step_value)
                summary_writer_t.flush()
                summary_writer_v.add_run_metadata(run_metadata, 'step%d' % counter)
                summary_writer_v.add_summary(summary_str_v, global_step_value)
                summary_writer_v.flush()

            if (counter % self.epoch_iteration) == 0 and not (counter == 0):
                epocas = epocas + 1
                saver.save(sess, self.checkpoint)
                np.save('counter.npy', counter)
                np.save('epocas.npy', epocas)

            if self.is_test:
                if (counter % (self.test_each_epoch * self.epoch_iteration)) == 0:  # and not(counter==0):
                    dice_test_list = []
                    iou_test_list = []
                    test_list = np.load('./test_list.npy')
                    num_images = len(test_list[0])
                    list_flair = test_list[0]
                    list_gt = test_list[1]
                    for subject in tqdm(list_flair):
                        # vol_value = []
                        # vol_gt = []
                        print('one volume')
                        for test_iter in range(155):
                            test_i = np.load(self.data_dir + 'validation' + subject + str(test_iter) + '.npy')
                            test_i = np.reshape(test_i, newshape=[1, self.IM_ROWS, self.IM_COLS, 1])
                            test_m = np.load(self.data_dir + 'validation' + '/gt' + subject[3:] + str(test_iter) + '.npy')
                            test_m = np.reshape(test_m, newshape=[1, self.IM_ROWS, self.IM_COLS, 1])
                            slice_value = sess.run([u_test],
                                                   feed_dict={test_images: test_i, test_gt: test_m},
                                                   options=run_options,
                                                   run_metadata=run_metadata)
                            if test_iter == 0:
                                vol_value = slice_value
                                vol_gt = test_m
                            else:
                                vol_value = np.concatenate([vol_value, slice_value], axis=-1)
                                vol_gt = np.concatenate([vol_gt, test_m], axis=-1)
                            # vol_value.append(slice_value[0][0, :, :, :])
                            # vol_gt.append(test_m[0, :, :, :])
                        vol_value = np.reshape(vol_value, [self.IM_ROWS, self.IM_COLS, 155])
                        vol_gt = np.reshape(vol_gt, [self.IM_ROWS, self.IM_COLS, 155])
                        dice_coe_test_value = self.np_dice(vol_value, vol_gt)
                        iou_coe_test_value = self.np_iou(vol_value, vol_gt)
                        dice_test_list.append(dice_coe_test_value)
                        iou_test_list.append(iou_coe_test_value)

                    summary_str_test, global_step_value = sess.run(
                        [summary_test, global_step],
                        feed_dict={dice_test: np.mean(dice_test_list),
                                   iou_test: np.mean(iou_test_list)})
                    summary_writer_test.add_run_metadata(run_metadata, 'step%d' % epocas)
                    summary_writer_test.add_summary(summary_str_test, global_step_value)
                    summary_writer_test.flush()

            counter += 1

        sess.close()
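# ---------------------------------------------------------------------------
# The volume-wise evaluation above relies on self.np_dice / self.np_iou. A
# minimal NumPy sketch of those metrics on binary masks (the class may
# threshold the soft prediction differently); `np` is imported at module level.
def np_dice_sketch(pred, gt, thresh=0.5, eps=1e-5):
    p = (pred > thresh).astype(np.float32)
    g = (gt > thresh).astype(np.float32)
    inter = np.sum(p * g)
    return (2. * inter + eps) / (np.sum(p) + np.sum(g) + eps)

def np_iou_sketch(pred, gt, thresh=0.5, eps=1e-5):
    p = (pred > thresh).astype(np.float32)
    g = (gt > thresh).astype(np.float32)
    inter = np.sum(p * g)
    union = np.sum(p) + np.sum(g) - inter
    return (inter + eps) / (union + eps)
# ---------------------------------------------------------------------------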