def diceLoss(y_true, y_pred):
    """Return the negated dice coefficient of ``y_true`` and ``y_pred``.

    The sign is flipped so that minimising this value maximises the dice
    coefficient computed by ``dice``.
    """
    coefficient = dice(y_true, y_pred)
    return -coefficient
def metric(self, logit, truth):
    """Compute the evaluation metric, used in particular for early stopping.

    Delegates to ``dice(logit, truth)`` and returns its result unchanged.
    """
    # Alternative metric kept for reference: iou_pytorch(logit, truth)
    score = dice(logit, truth)
    return score
def validation(trained_net, val_set, criterion, device, batch_size, ignore_idx,
               name=None, epoch=None):
    """Run one validation pass and report mean loss and dice scores.

    The first batch is additionally handed to ``heatmap_plot`` together with
    ``name`` and ``epoch``.

    Args:
        trained_net: network called on every image batch.
        val_set: dataset wrapper; must support ``len()`` and ``load()``.
        criterion: callable ``criterion(outputs, segs)`` returning the loss.
        device: device the image and label tensors are moved to.
        batch_size: used to derive the batch count that all sums are divided by.
        ignore_idx: forwarded to ``dice`` as ``ignore_idx``.
        name: forwarded to ``heatmap_plot``.
        epoch: forwarded to ``heatmap_plot``.

    Returns:
        A tuple ``(mean_loss, mean_avg_dice, val_info)`` where ``val_info`` maps
        ``'bg'``, ``'wm'``, ``'gm'``, ``'csf'`` and ``'tm'`` to their mean dice.
        Every mean divides by ``ceil(len(val_set) / batch_size)``.
    """
    n_val = len(val_set)
    val_loader = val_set.load()

    class_keys = ('bg', 'wm', 'gm', 'csf', 'tm')
    dice_totals = {key: 0 for key in class_keys}
    loss_total = 0
    avg_dice_total = 0

    with tqdm(total=n_val, desc='Validation round', unit='patch',
              leave=False) as pbar, torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            images = sample['image'].to(device=device)
            segs = sample['seg'].to(device=device)

            outputs = trained_net(images)
            val_loss = criterion(outputs, segs)

            # Only the first batch of the pass is visualised.
            if batch_idx == 0:
                heatmap_plot(image=images.detach().cpu().numpy()[0],
                             mask=segs.detach().cpu().numpy()[0],
                             pred=outputs.detach().cpu().numpy()[0],
                             name=name,
                             epoch=epoch,
                             is_train=False)

            # Collapse every leading dimension into one before the arg-max
            # over dimension 1.
            outputs = outputs.view(-1, *outputs.shape[-4:])
            segs = segs.view(-1, *segs.shape[-3:])
            preds = torch.max(outputs.data, 1)[1]

            dice_score = dice(preds.data.cpu(), segs.data.cpu(),
                              ignore_idx=ignore_idx)
            for key in class_keys:
                dice_totals[key] += dice_score[key]

            batch_loss = val_loss.detach().item()
            loss_total += batch_loss
            avg_dice_total += dice_score['avg']

            pbar.set_postfix(**{
                'validation loss (images)': batch_loss,
                'val_acc_avg': dice_score['avg']
            })
            pbar.update(images.shape[0])

    n_batches = np.ceil(n_val / batch_size)
    val_info = {key: dice_totals[key] / n_batches for key in class_keys}
    return loss_total / n_batches, avg_dice_total / n_batches, val_info
def dice_loss(y_pred, y, c_weights=None):
    """Return ``1 - dice(y_pred, y, c_weights)``.

    ``c_weights`` is forwarded to ``dice`` unchanged.
    """
    coefficient = dice(y_pred, y, c_weights)
    return 1 - coefficient
def train(args):
    """Train an LSTM-based segmentation network.

    Run options come from ``args`` (model name/type, LSTM variant, loss,
    distributed rank/world size, ...) and from ``config.yaml`` (paths, batch
    size, patience, optimizer, ...). Each epoch rebuilds the data loaders,
    trains, validates, steps the LR scheduler on the mean training loss, writes
    a checkpoint through ``save_ckp`` and dumps the train/validation records
    as JSON. Training stops early when ``save_ckp`` returns a truthy value.

    Args:
        args: namespace providing ``model_name``, ``model_type``, ``lstmbase``,
            ``unetbase``, ``layer_num``, ``nb_shortcut``, ``loss_fn``,
            ``world_size``, ``rank``, ``base_channels``, ``crop_size``,
            ``ignore_idx``, ``return_sequence``, ``LSTM_variant``, ``epoch``
            and ``is_pretrain``.

    Raises:
        NotImplementedError: for an unknown model type, LSTM variant, loss
            function or optimizer name.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    # A learning rate is fractional; int() would truncate e.g. 0.01 to 0.
    lr = float(config['PARAMETERS']['lr'])
    optimizer_name = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train_split, val_split, _ = data_construction(data_class)
    train_dict = time_parser(train_split, time_patch=time_step)
    val_dict = time_parser(val_split, time_patch=time_step)

    # LSTM initilization
    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone,
                         input_dim=input_modalites,
                         output_dim=output_channels,
                         hidden_dim=base_channel,
                         kernel_size=3,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        # Keyword arguments shared by every "center"-style variant.
        center_kwargs = dict(input_modalites=input_modalites,
                             output_channels=output_channels,
                             base_channel=base_channel,
                             num_layers=layer_num,
                             conv_type=conv_type,
                             return_sequence=return_sequence,
                             is_pretrain=is_pretrain)
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites,
                           hidden_dim=base_channel,
                           output_dim=output_channels,
                           kernel_size=3,
                           num_layers=layer_num,
                           conv_type=conv_type,
                           lstm_backbone=lstm_backbone,
                           unet_module=unet_backbone,
                           base_channel=base_channel,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
            logger.info(
                'the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(**center_kwargs)
        elif variant == 'bicenter':
            net = BiCenterLSTM(connect=connect, **center_kwargs)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(**center_kwargs)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(connect=connect, **center_kwargs)
        elif variant == 'rescenter':
            net = ResCenterLSTM(**center_kwargs)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(connect=connect, **center_kwargs)
        elif variant == 'shortcut':
            net = ShortcutLSTM(num_connects=nb_shortcut, **center_kwargs)
        else:
            raise NotImplementedError()
    else:
        # Fail here instead of with a NameError on `net` further down.
        raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer_name == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        # optimizer = optim.Adam(net.parameters())
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    else:
        # Otherwise the scheduler would be handed the config string.
        raise NotImplementedError()

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net,
                                                  find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))
        except Exception:
            # Best-effort resume; `Exception` (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            logger.warning(
                'No checkpoint available, strat training from scratch')

    for epoch in range(start_epoch, epochs):
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                model_type='RNN')
        n_train = len(train_set)

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              model_type='CNN')
        n_val = len(val_set)

        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # Reset gradients; without this every backward() call adds
                # onto the gradients of all previous iterations.
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()

                # Collapse every leading dimension into one before the
                # arg-max over dimension 1.
                outputs = outputs.view(-1, outputs.shape[-4],
                                       outputs.shape[-3], outputs.shape[-2],
                                       outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2],
                                 segs.shape[-1])
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=None)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                # Validate once the expected number of batches has been seen.
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net,
                                                             val_set,
                                                             criterion,
                                                             device,
                                                             batch_size,
                                                             ignore_idx=None,
                                                             name=model_name,
                                                             epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save bast trained model
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'
            .format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                # Parameters that took no part in the last backward pass
                # have no gradient to log.
                if layer.grad is not None:
                    writer.add_histogram(name + '_grad',
                                         layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data',
                                     layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    # Flush pending TensorBoard events before returning.
    writer.close()
    logger.info('finish training!')

    return
def evulate_dir_nii_weakly_new():
    """Evaluate the weakly-labelled segmentation model (new version).

    Builds the UNet inference graph, restores the hard-coded checkpoint and
    runs every slice of the selected volume(s) at each entry of the
    module-level ``scales``. Slices whose ground-truth mask is non-empty are
    thresholded, cleaned up and refined with GrabCut; debug PNGs for those
    slices go to ``./tmp``. Per-case and global dice / IoU, for both the raw
    and the GrabCut predictions, are printed.

    NOTE: only ``volume-0.nii`` is currently processed; every other volume is
    skipped by the filter at the top of the case loop.

    :return: None
    """
    from metrics import dice, IoU
    from datasets.medicalImage import convertCase2PNGs, image_expand
    from grabcut import grabcut

    nii_dir = '/home/give/Documents/dataset/ISBI2017/Training_Batch_1'
    save_dir = '/home/give/Documents/dataset/ISBI2017/weakly_label_segmentation_V4/Batch_1/DLSC_0/niis'
    # restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/ISBI2017_V2/1s_agumentation_weakly-upsampling-2/model.ckpt-168090'
    restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/ISBI2017_V2/1s_agumentation_weakly_V3-upsampling-1/model.ckpt-167824'

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.float32, shape=[None, None, 3])
        image_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        input_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image,
            None,
            None,
            out_shape=input_shape_placeholder,
            is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNetBlocksMS.UNet(b_image,
                                None,
                                None,
                                is_training=False,
                                decoder=FLAGS.decoder,
                                update_center_flag=FLAGS.update_center,
                                batch_size=2,
                                init_center_value=None,
                                update_center_strategy=1,
                                num_centers_k=FLAGS.num_centers_k,
                                full_annotation_flag=False,
                                output_shape_tensor=input_shape_placeholder)

        # print slim.get_variables_to_restore()
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    # NOTE(review): `variables_to_restore` is computed but never passed to
    # the Saver below, so the default variable set is restored either way.
    # Confirm whether that is intended.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver()

    nii_pathes = glob(os.path.join(nii_dir, 'volume-*.nii'))
    checkpoint = restore_path

    # cv2.imwrite fails silently when the target directory is missing.
    os.makedirs('./tmp', exist_ok=True)

    # pixel_recovery_features = tf.image.resize_images(net.pixel_recovery_features, image_shape_placeholder)
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)

        global_gt = []
        global_pred = []
        global_pred_grabcut = []

        case_dices = []
        case_IoUs = []
        case_dices_grabcut = []
        case_IoUs_grabcut = []

        for nii_path in nii_pathes:
            # if os.path.basename(nii_path) in ['volume-15.nii', 'volume-25.nii']:
            #     continue
            nii_path_basename = os.path.basename(nii_path)
            if nii_path_basename != 'volume-0.nii':
                continue

            # Take the case id from the file name only, so '.' or '-'
            # characters in the directory part cannot corrupt it.
            case_id = nii_path_basename.split('.')[0].split('-')[1]

            pred_dir = os.path.join(save_dir, case_id, 'pred')
            pred_vis_dir = os.path.join(save_dir, case_id, 'pred_vis')
            recovery_img_dir = os.path.join(save_dir, case_id, 'recovery_img')
            for out_dir in (pred_dir, pred_vis_dir, recovery_img_dir):
                os.makedirs(out_dir, exist_ok=True)

            seg_path = os.path.join(nii_dir,
                                    'segmentation-' + case_id + '.nii')

            case_preds = []
            case_gts = []
            case_preds_grabcut = []

            # case_recover_features = []
            print(nii_path, seg_path)
            imgs, tumor_masks, liver_masks, tumor_weak_masks = convertCase2PNGs(
                nii_path, seg_path, save_dir=None)
            # NOTE(review): the last value repeats len(tumor_masks);
            # `tumor_weak_masks` may have been intended. Left unchanged.
            print(len(imgs), len(tumor_masks), len(liver_masks),
                  len(tumor_masks))

            for slice_idx, (image_data, liver_mask, whole_mask) in enumerate(
                    zip(imgs, liver_masks, tumor_masks)):
                pixel_cls_scores_ms = []
                pixel_recover_feature_ms = []
                for single_scale in scales:
                    pixel_recover_feature, pixel_cls_scores, b_image_v, global_step_v, net_centers = sess.run(
                        [
                            net.pixel_recovery_features,
                            net.pixel_cls_scores, b_image, global_step,
                            net.centers
                        ],
                        feed_dict={
                            image: image_data,
                            image_shape_placeholder:
                            np.shape(image_data)[:2],
                            input_shape_placeholder: single_scale
                        })
                    # Resize the channel-1 score map back to the slice
                    # resolution (cv2 expects (width, height)).
                    pixel_cls_scores_ms.append(
                        cv2.resize(pixel_cls_scores[0, :, :, 1],
                                   tuple(np.shape(image_data)[:2][::-1])))
                    pixel_recover_feature_ms.append(pixel_recover_feature[0])
                    del pixel_recover_feature

                # Average the per-scale outputs.
                pixel_cls_scores = np.mean(pixel_cls_scores_ms, axis=0)
                pixel_recover_feature = np.mean(pixel_recover_feature_ms,
                                                axis=0)
                # case_recover_features.append(pixel_recover_feature)

                if np.sum(whole_mask) != 0:
                    pred = np.asarray(pixel_cls_scores > 0.6, np.uint8)
                    # Opening: erode first, then dilate.
                    # Closing: dilate first, then erode.
                    # pred = close_operation(pred, kernel_size=3)
                    pred = open_operation(pred, kernel_size=3)
                    pred = fill_region(pred)

                    pred_grabcut = np.asarray(
                        image_expand(pred, kernel_size=5), np.uint8)
                    xs, ys = np.where(pred_grabcut == 1)
                    print(np.shape(pred_grabcut))
                    if len(xs) == 0:
                        pred_grabcut = np.zeros_like(whole_mask)
                        print(np.min(pred_grabcut), np.max(pred_grabcut),
                              np.sum(pred_grabcut))
                    else:
                        # Bounding box of the expanded prediction, passed to
                        # grabcut together with the 8-bit image.
                        min_xs = np.min(xs)
                        max_xs = np.max(xs)
                        min_ys = np.min(ys)
                        max_ys = np.max(ys)
                        pred_grabcut = grabcut(
                            np.asarray(image_data * 255., np.uint8),
                            [min_xs, min_ys, max_xs, max_ys])
                        pred_grabcut = np.asarray(pred_grabcut == 255,
                                                  np.uint8)
                        print(np.unique(pred_grabcut))

                    cv2.imwrite('./tmp/%d_gt.png' % slice_idx,
                                np.asarray(whole_mask * 200, np.uint8))
                    cv2.imwrite('./tmp/%d_pred.png' % slice_idx,
                                np.asarray(pred * 200, np.uint8))
                    cv2.imwrite('./tmp/%d_pred_grabcut.png' % slice_idx,
                                np.asarray(pred_grabcut * 200, np.uint8))
                else:
                    pred = np.zeros_like(whole_mask)
                    pred_grabcut = np.zeros_like(whole_mask)

                global_gt.append(whole_mask)
                case_gts.append(whole_mask)
                case_preds.append(pred)
                case_preds_grabcut.append(pred_grabcut)
                global_pred.append(pred)
                global_pred_grabcut.append(pred_grabcut)

                print('%d / %d: %s' % (slice_idx + 1, len(imgs),
                                       os.path.basename(nii_path)),
                      np.shape(pixel_cls_scores), np.max(pixel_cls_scores),
                      np.min(pixel_cls_scores),
                      np.shape(pixel_recover_feature))
                del pixel_recover_feature, pixel_recover_feature_ms
                gc.collect()

            case_dice = dice(case_gts, case_preds)
            case_IoU = IoU(case_gts, case_preds)
            case_dice_grabcut = dice(case_gts, case_preds_grabcut)
            case_IoU_grabcut = IoU(case_gts, case_preds_grabcut)
            print('case dice: ', case_dice)
            print('case IoU: ', case_IoU)
            print('case dice grabcut: ', case_dice_grabcut)
            print('case IoU grabcut: ', case_IoU_grabcut)
            case_dices.append(case_dice)
            case_IoUs.append(case_IoU)
            case_dices_grabcut.append(case_dice_grabcut)
            case_IoUs_grabcut.append(case_IoU_grabcut)

        print('global dice is ', dice(global_gt, global_pred))
        print('global IoU is ', IoU(global_gt, global_pred))

        print('mean of case dice is ', np.mean(case_dices))
        print('mean of case IoU is ', np.mean(case_IoUs))

        print('mean of case dice (grabcut) is ', np.mean(case_dices_grabcut))
        print('mean of case IoU (grabcut) is ', np.mean(case_IoUs_grabcut))

        print('global dice (grabcut) is ',
              dice(global_gt, global_pred_grabcut))
        print('global IoU (grabcut) is ',
              IoU(global_gt, global_pred_grabcut))
def validation(trained_net, val_set, criterion, device, batch_size,
               model_type=None, softmax=True, ignore_idx=None):
    """Evaluate the network on the validation set during training.

    Each batch is unpacked as a 6-D tensor and its first two dimensions are
    merged into one before the forward pass.

    Args:
        trained_net: network to evaluate.
        val_set: validation dataset; must support ``len()`` and ``load()``.
        criterion: loss function called as ``criterion(preds, segs)``.
        device: device the image and label tensors are moved to.
        batch_size: used to derive the batch count that all sums are divided by.
        model_type: labels are cast to ``long`` when this is
            ``'SkipDenseSeg'`` and ``softmax`` is false.
        softmax: see ``model_type``.
        ignore_idx: forwarded positionally to ``dice``.

    Returns:
        A tuple ``(mean_loss, mean_avg_dice, val_info)`` where ``val_info`` maps
        ``'bg'``, ``'wm'``, ``'gm'``, ``'csf'`` and ``'tm'`` to their mean dice.
        Every mean divides by ``ceil(len(val_set) / batch_size)``.
    """
    n_val = len(val_set)
    val_loader = val_set.load()

    class_keys = ('bg', 'wm', 'gm', 'csf', 'tm')
    dice_totals = {key: 0 for key in class_keys}
    loss_total = 0
    avg_dice_total = 0
    cast_labels = model_type == 'SkipDenseSeg' and not softmax

    with tqdm(total=n_val, desc='Validation round', unit='patch',
              leave=False) as pbar, torch.no_grad():
        for sample in val_loader:
            images = sample['image'].to(device=device)
            segs = sample['seg'].to(device=device)

            if cast_labels:
                segs = segs.long()

            # Merge the first two dimensions of the 6-D batch into one.
            _, _, channel, depth, height, width = images.shape
            images = images.view(-1, channel, depth, height, width)
            segs = segs.view(-1, depth, height, width)

            scores = trained_net(images)
            val_loss = criterion(scores, segs)
            preds = torch.max(scores, 1)[1]

            dice_score = dice(preds.data.cpu(), segs.data.cpu(), ignore_idx)
            for key in class_keys:
                dice_totals[key] += dice_score[key]

            batch_loss = val_loss.detach().item()
            loss_total += batch_loss
            avg_dice_total += dice_score['avg']

            pbar.set_postfix(**{
                'validation loss': batch_loss,
                'val_acc_avg': dice_score['avg']
            })
            pbar.update(images.shape[0])

    n_batches = np.ceil(n_val / batch_size)
    val_info = {key: dice_totals[key] / n_batches for key in class_keys}
    return loss_total / n_batches, avg_dice_total / n_batches, val_info
def train(args):
    """Train a CNN (U-Net family) segmentation network.

    Run options come from ``args`` and ``config.yaml``. Each epoch rebuilds the
    data loaders, trains on batches whose first two dimensions are merged,
    validates, steps the LR scheduler on the validation loss, writes a
    checkpoint through ``save_ckp`` and dumps the train/validation records as
    JSON. Training stops early when ``save_ckp`` returns a truthy value.

    Args:
        args: namespace providing ``model_name``, ``model_type``, ``loss``,
            ``world_size``, ``rank``, ``base_channels``, ``crop_size``,
            ``ignore_idx`` and ``epoch``.

    Raises:
        NotImplementedError: for an unknown model type, loss function or
            optimizer name.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    # A learning rate is fractional; int() would truncate e.g. 0.01 to 0.
    lr = float(config['PARAMETERS']['lr'])
    optimizer_name = config['PARAMETERS']['optimizer']
    softmax = True
    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    data_class = data_split()
    train_split, val_split, _ = data_construction(data_class)
    train_dict = time_parser(train_split, time_patch=time_step)
    val_dict = time_parser(val_split, time_patch=time_step)

    # groups = 4
    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel,
                         pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel,
                      pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel,
                       pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels,
                                  base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel,
                             softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels,
                                        base_channel, softmax)
    else:
        # Fail here instead of with a NameError on `net` further down.
        raise NotImplementedError()

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net, input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size,
                             crop_size).to(device)
    writer.add_graph(net, (dummy_input, ))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        criterion = WeightedCrossEntropyLoss(labels=labels)
        if not softmax:
            # Follow the selected device rather than assuming CUDA exists.
            criterion = nn.CrossEntropyLoss().to(device)
    else:
        raise NotImplementedError()

    if optimizer_name == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              momentum=0.9,
                              lr=lr,
                              weight_decay=1e-5)
    else:
        # Otherwise the scheduler would be handed the config string.
        raise NotImplementedError()

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs avaliable'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        logger.info('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net)
    else:
        logger.info('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))
            # min_loss = float('Inf')
        except Exception:
            # Best-effort resume; `Exception` (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            logger.warning(
                'No checkpoint available, strat training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # every epoch generate a new set of images
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                modalities=modalities,
                                model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              modalities=modalities,
                              model_type='CNN')
        n_val = len(val_set)

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                if model_type == 'SkipDenseSeg' and not softmax:
                    segs = segs.long()

                # combine the batch and time step
                batch, time, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                # with amp.scale_loss(loss, optimizer) as scaled_loss:
                #     scaled_loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()

                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=ignore_idx)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                del images, segs

                global_step += 1
                # Validate once the expected number of batches has been seen.
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(
                        net,
                        val_set,
                        criterion,
                        device,
                        batch_size,
                        model_type=model_type,
                        softmax=softmax,
                        ignore_idx=ignore_idx)
                    # Restore train mode in case further batches follow.
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save bast trained model
        # ReduceLROnPlateau.step() always needs the monitored metric, so the
        # former metric-less call for 'SkipDenseSeg' could never have worked.
        scheduler.step(val_loss)
        # debug
        for param_group in optimizer.param_groups:
            logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

        if min_loss > running_loss / nb_batches + 1e-2:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        # save the check point
        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation accuracy: {}'.format(
            val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        loss_plot(train_info_file, name=model_name)

        for name, layer in net.named_parameters():
            # Parameters without a gradient (frozen or unused in the last
            # backward pass) have nothing to log for '_grad'.
            if layer.grad is not None:
                writer.add_histogram(name + '_grad',
                                     layer.grad.cpu().data.numpy(), epoch)
            writer.add_histogram(name + '_data',
                                 layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
height: height + crop_size[1], width: width + crop_size[2]] += 1 whole_pred = whole_pred / count_used whole_pred = whole_pred[0, :, :, :, :] whole_pred = np.argmax(whole_pred, axis=0) print (whole_pred.shape, label.shape) label=label.transpose(0,2,1) print(whole_pred.shape, label.shape) dsc = [] print ('-------------------------') for i in range(1, num_classes): dsc_i = dice(whole_pred, label, i) dsc_i=round(dsc_i*100,2) dsc.append(dsc_i) datetime= time.strftime("%d/%m/%Y") print('Data | Note | class1| class2|class3| Avg.|') print('%s | %s | %2.2f | %2.2f | %2.2f | %2.2f |' % ( \ datetime, note, dsc[0], dsc[1], dsc[2], np.mean(dsc)))