def predict_CNN(args):
    config_file = 'config.yaml'
    config = load_config(config_file)

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    root_path = config['PATH']['model_root']
    best_dir = config['PATH']['save_best_model']
    best_path = os.path.join(root_path, best_dir)

    model_type = args.net
    model_name = args.model_name
    crop_size = args.crop_size
    overlap_size = args.overlap_size
    base_channels = args.base_channels

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # load best trained model
    if model_name.startswith('UNet'):
        net = init_U_Net(input_modalites, output_channels, base_channels)
    elif model_name.startswith('direct'):
        net = U_Net_direct_concat(input_modalites, output_channels,
                                  base_channels)
    elif model_name.startswith('ResUNet'):
        net = ResUNet(input_modalites, output_channels, base_channels)
    elif model_name.startswith('DResUNet'):
        net = DResUNet(input_modalites, output_channels, base_channels)
    else:
        raise NotImplementedError()

    if distributed_is_initialized():
        net = nn.parallel.DistributedDataParallel(net)
    else:
        net = nn.DataParallel(net)
    net.to(device)

    ckp_path = os.path.join(best_path, model_name + '_best_model.pth.tar')
    checkpoint = torch.load(ckp_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    print_size_of_model(net)

    # predict
    data_class = data_split()
    train, val, test = data_construction(data_class)
    test_dict = time_parser(test)
    patient_id = [key for key in test_dict.keys()]
    modalities = ['flair', 't1', 't1gd', 't2']

    Dice = {}
    CSF_Dice = []
    GM_Dice = []
    WM_Dice = []
    TM_Dice = []

    HD95 = {}
    CSF_HD = []
    GM_HD = []
    WM_HD = []
    TM_HD = []

    ASD = {}
    CSF_ASD = []
    GM_ASD = []
    WM_ASD = []
    TM_ASD = []

    for i in range(len(patient_id)):
        time_dict = test_dict[patient_id[i]]
        time_dict = sorted(time_dict.items(), key=lambda item: item[0])
        predicted_masks = {}

        patient_inference_folder = os.path.join('inference_result',
                                                patient_id[i])
        if not os.path.exists(patient_inference_folder):
            os.makedirs(patient_inference_folder)

        for time_point in time_dict:
            print('Predicting patient {} at time {}'.format(
                patient_id[i], time_point[0]))

            brain_mask = sitk.GetArrayFromImage(
                sitk.ReadImage(time_point[1]['brainmask']))
            images = [
                sitk.GetArrayFromImage(sitk.ReadImage(time_point[1][modality]))
                * brain_mask for modality in modalities
            ]
            images = np.stack(images)
            mask = sitk.ReadImage(time_point[1]['combined_fast'])
            image_shape = images.shape[-3:]

            inferenced_mask = predict(trained_net=net,
                                      image=images,
                                      image_shape=image_shape,
                                      crop_size=crop_size,
                                      overlap_size=overlap_size,
                                      model_type=model_type)

            Dice_score, HD95_score, ASD_score = plot_save(
                images[0],
                inferenced_mask,
                segmentation=mask,
                inference_folder=patient_inference_folder,
                model_name=args.model_name,
                inference_name=args.model_name + '_' + patient_id[i] + '_' +
                time_point[0],
                save_mask=True)

            CSF_Dice.append(Dice_score['csf'])
            GM_Dice.append(Dice_score['gm'])
            WM_Dice.append(Dice_score['wm'])
            TM_Dice.append(Dice_score['tm'])

            CSF_HD.append(HD95_score['csf'])
            GM_HD.append(HD95_score['gm'])
            WM_HD.append(HD95_score['wm'])
            TM_HD.append(HD95_score['tm'])

            CSF_ASD.append(ASD_score['csf'])
            GM_ASD.append(ASD_score['gm'])
            WM_ASD.append(ASD_score['wm'])
            TM_ASD.append(ASD_score['tm'])

            predicted_masks[time_point[0]] = inferenced_mask

        plot_volumn_dev(test_dict,
                        patient_id[i],
                        model_name=model_name,
                        predicted_labels=predicted_masks)
        # only the first test patient is evaluated (early break)
        break

    Dice['csf'] = CSF_Dice
    Dice['gm'] = GM_Dice
    Dice['wm'] = WM_Dice
    Dice['tm'] = TM_Dice

    HD95['csf'] = CSF_HD
    HD95['gm'] = GM_HD
    HD95['wm'] = WM_HD
    HD95['tm'] = TM_HD

    ASD['csf'] = CSF_ASD
    ASD['gm'] = GM_ASD
    ASD['wm'] = WM_ASD
    ASD['tm'] = TM_ASD

    dice_dir = os.path.join('inference_result', 'dice_' + args.model_name)
    HD_dir = os.path.join('inference_result', 'HD95_' + args.model_name)
    ASD_dir = os.path.join('inference_result', 'ASD_' + args.model_name)
    if not os.path.exists(dice_dir):
        os.mkdir(dice_dir)
    if not os.path.exists(HD_dir):
        os.mkdir(HD_dir)
    if not os.path.exists(ASD_dir):
        os.mkdir(ASD_dir)

    dice_file = os.path.join(dice_dir, 'dice.json')
    HD_file = os.path.join(HD_dir, 'HD95.json')
    ASD_file = os.path.join(ASD_dir, 'ASD.json')
    with open(dice_file, 'w') as f:
        json.dump(Dice, f)
    with open(HD_file, 'w') as f:
        json.dump(HD95, f)
    with open(ASD_file, 'w') as f:
        json.dump(ASD, f)

    box_plot(Dice, dice_dir, metric='Dice')
    box_plot(HD95, HD_dir, metric='Hausdorff')
    box_plot(ASD, ASD_dir, metric='ASD')
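# ---------------------------------------------------------------------------
# The argparse wiring behind `args` lives elsewhere in the repository. A
# minimal sketch of a compatible entry point for predict_CNN(), using only
# the attribute names the function actually reads (net, model_name,
# crop_size, overlap_size, base_channels); the flag defaults below are
# illustrative assumptions, not the project's real values.
# ---------------------------------------------------------------------------
def _example_predict_cnn_args():
    import argparse
    parser = argparse.ArgumentParser(description='CNN inference (sketch)')
    parser.add_argument('--net', type=str, default='UNet')  # model_type for predict()
    parser.add_argument('--model_name', type=str, default='UNet-p64')
    parser.add_argument('--crop_size', type=int, default=64)  # sliding-window patch edge
    parser.add_argument('--overlap_size', type=int, default=16)  # overlap of adjacent patches
    parser.add_argument('--base_channels', type=int, default=4)
    # parse an empty list so the sketch is runnable without a command line
    return parser.parse_args([])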
def Longitudinal_predict(args):
    config_file = 'config.yaml'
    config = load_config(config_file)

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    conv_type = config['PARAMETERS']['lstm_convtype']
    connect = config['PARAMETERS']['connect']
    root_path = config['PATH']['model_root']
    best_dir = config['PATH']['save_best_model']
    best_path = os.path.join(root_path, best_dir)

    model_name = args.model_name
    crop_size = args.crop_size
    overlap_size = args.overlap_size
    base_channels = args.base_channels
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    return_sequence = args.return_sequence
    nb_shortcut = args.nb_shortcut
    is_pretrain = args.is_pretrain

    inference_step = 3

    if model_name.startswith('Back'):
        net = BackLSTM(input_dim=input_modalites,
                       hidden_dim=base_channels,
                       output_dim=output_channels,
                       kernel_size=3,
                       num_layers=layer_num,
                       conv_type=conv_type,
                       lstm_backbone=lstm_backbone,
                       unet_module=unet_backbone,
                       base_channel=base_channels,
                       return_sequence=return_sequence,
                       is_pretrain=is_pretrain)
    elif model_name.startswith('CenterLSTM'):
        net = CenterLSTM(input_modalites=input_modalites,
                         output_channels=output_channels,
                         base_channel=base_channels,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence,
                         is_pretrain=is_pretrain)
    elif model_name.startswith('CenterDenseBiLSTM') or model_name.startswith(
            'CenterNormalBiLSTM'):
        net = BiCenterLSTM(input_modalites=input_modalites,
                           output_channels=output_channels,
                           base_channel=base_channels,
                           num_layers=layer_num,
                           connect=connect,
                           conv_type=conv_type,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
    elif model_name.startswith('BiDirectCenterNormal') or model_name.startswith(
            'BiDirectCenterDense'):
        net = BiDirectCenterLSTM(input_modalites=input_modalites,
                                 output_channels=output_channels,
                                 base_channel=base_channels,
                                 num_layers=layer_num,
                                 connect=connect,
                                 conv_type=conv_type,
                                 return_sequence=return_sequence,
                                 is_pretrain=is_pretrain)
    elif model_name.startswith('BiResCenterNormal') or model_name.startswith(
            'BiResCenterDense'):
        net = BiResCenterLSTM(input_modalites=input_modalites,
                              output_channels=output_channels,
                              base_channel=base_channels,
                              num_layers=layer_num,
                              connect=connect,
                              conv_type=conv_type,
                              return_sequence=return_sequence,
                              is_pretrain=is_pretrain)
    elif model_name.startswith('Shortcut'):
        net = ShortcutLSTM(input_modalites=input_modalites,
                           output_channels=output_channels,
                           base_channel=base_channels,
                           num_layers=layer_num,
                           num_connects=nb_shortcut,
                           conv_type=conv_type,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
    else:
        raise NotImplementedError()

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if distributed_is_initialized():
        net = nn.parallel.DistributedDataParallel(net)
    else:
        net = nn.DataParallel(net)
    net.to(device)

    ckp_path = os.path.join(best_path, model_name + '_best_model.pth.tar')
    checkpoint = torch.load(ckp_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    print('{} size is: '.format(model_name))
    print_size_of_model(net)

    # predict
    data_class = data_split()
    train, val, test = data_construction(data_class)
    test_dict = time_parser(test)
    patient_id = [key for key in test_dict.keys()]
    # patient_id = ['EGD-0505']
    modalities = ['flair', 't1', 't1gd', 't2']

    Dice = {}
    CSF_Dice = []
    GM_Dice = []
    WM_Dice = []
    TM_Dice = []

    HD95 = {}
    CSF_HD = []
    GM_HD = []
    WM_HD = []
    TM_HD = []

    ASD = {}
    CSF_ASD = []
    GM_ASD = []
    WM_ASD = []
    TM_ASD = []

    for i in range(len(patient_id)):
        predicted_masks = {}
        time_dict = test_dict[patient_id[i]]
        time_dict = sorted(time_dict.items(), key=lambda item: item[0])
        tot_timesteps = len(time_dict)
        fold = int(np.ceil(tot_timesteps / inference_step))

        patient_inference_folder = os.path.join('inference_result',
                                                patient_id[i])
        if not os.path.exists(patient_inference_folder):
            os.makedirs(patient_inference_folder)

        for j in range(fold):
            selected_time_points = time_dict[j * inference_step:(j + 1) *
                                             inference_step]
            image_stack = []
            seg = []
            time_list = []

            # load every time point of this group
            for time_point in selected_time_points:
                time_list.append(time_point[0])
                brain_mask = sitk.GetArrayFromImage(
                    sitk.ReadImage(time_point[1]['brainmask']))
                images = [
                    sitk.GetArrayFromImage(
                        sitk.ReadImage(time_point[1][modality])) * brain_mask
                    for modality in modalities
                ]
                images = np.stack(images)
                image_stack.append(images)
                mask = sitk.ReadImage(time_point[1]['combined_fast'])
                seg.append(mask)
                image_shape = images.shape[-3:]

            # predict per group
            inferenced_mask = predict(trained_net=net,
                                      image=np.stack(image_stack),
                                      image_shape=image_shape,
                                      crop_size=crop_size,
                                      overlap_size=overlap_size,
                                      model_type='RNN')

            for k in range(inferenced_mask.shape[0]):
                Dice_score, HD95_score, ASD_score = plot_save(
                    image_stack[k][0],
                    inferenced_mask[k],
                    segmentation=seg[k],
                    inference_folder=patient_inference_folder,
                    model_name=args.model_name,
                    inference_name=args.model_name + '_' + patient_id[i] +
                    '_' + time_list[k],
                    save_mask=True)

                predicted_masks[time_list[k]] = inferenced_mask[k]

                CSF_Dice.append(Dice_score['csf'])
                GM_Dice.append(Dice_score['gm'])
                WM_Dice.append(Dice_score['wm'])
                TM_Dice.append(Dice_score['tm'])

                CSF_HD.append(HD95_score['csf'])
                GM_HD.append(HD95_score['gm'])
                WM_HD.append(HD95_score['wm'])
                TM_HD.append(HD95_score['tm'])

                CSF_ASD.append(ASD_score['csf'])
                GM_ASD.append(ASD_score['gm'])
                WM_ASD.append(ASD_score['wm'])
                TM_ASD.append(ASD_score['tm'])

        plot_volumn_dev(test_dict,
                        patient_id[i],
                        model_name=model_name,
                        predicted_labels=predicted_masks)

    Dice['csf'] = CSF_Dice
    Dice['gm'] = GM_Dice
    Dice['wm'] = WM_Dice
    Dice['tm'] = TM_Dice

    HD95['csf'] = CSF_HD
    HD95['gm'] = GM_HD
    HD95['wm'] = WM_HD
    HD95['tm'] = TM_HD

    ASD['csf'] = CSF_ASD
    ASD['gm'] = GM_ASD
    ASD['wm'] = WM_ASD
    ASD['tm'] = TM_ASD

    dice_dir = os.path.join('inference_result', 'dice_' + args.model_name)
    HD_dir = os.path.join('inference_result', 'HD95_' + args.model_name)
    ASD_dir = os.path.join('inference_result', 'ASD_' + args.model_name)
    if not os.path.exists(dice_dir):
        os.mkdir(dice_dir)
    if not os.path.exists(HD_dir):
        os.mkdir(HD_dir)
    if not os.path.exists(ASD_dir):
        os.mkdir(ASD_dir)

    dice_file = os.path.join(dice_dir, 'dice.json')
    HD_file = os.path.join(HD_dir, 'HD95.json')
    ASD_file = os.path.join(ASD_dir, 'ASD.json')
    with open(dice_file, 'w') as f:
        json.dump(Dice, f)
    with open(HD_file, 'w') as f:
        json.dump(HD95, f)
    with open(ASD_file, 'w') as f:
        json.dump(ASD, f)

    box_plot(Dice, dice_dir, metric='Dice')
    box_plot(HD95, HD_dir, metric='HD')
    box_plot(ASD, ASD_dir, metric='ASD')
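# ---------------------------------------------------------------------------
# Longitudinal_predict() feeds the LSTM in consecutive windows of
# `inference_step` timepoints. A standalone sketch of that grouping
# arithmetic (the function name is illustrative), handy for sanity-checking
# the fold loop above:
# ---------------------------------------------------------------------------
def _example_group_timepoints(time_points, inference_step=3):
    """Split an ordered list of timepoints into consecutive groups of at
    most `inference_step` items, mirroring the fold loop above."""
    import numpy as np
    fold = int(np.ceil(len(time_points) / inference_step))
    return [
        time_points[j * inference_step:(j + 1) * inference_step]
        for j in range(fold)
    ]

# e.g. _example_group_timepoints(['t0', 't1', 't2', 't3', 't4']) returns
# [['t0', 't1', 't2'], ['t3', 't4']] -- the last group may be shorter.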
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    lr = float(config['PARAMETERS']['lr'])  # int() would truncate a fractional learning rate
    optimizer = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # LSTM initialization
    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone,
                         input_dim=input_modalites,
                         output_dim=output_channels,
                         hidden_dim=base_channel,
                         kernel_size=3,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites,
                           hidden_dim=base_channel,
                           output_dim=output_channels,
                           kernel_size=3,
                           num_layers=layer_num,
                           conv_type=conv_type,
                           lstm_backbone=lstm_backbone,
                           unet_module=unet_backbone,
                           base_channel=base_channel,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
            logger.info(
                'the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites,
                             output_channels=output_channels,
                             base_channel=base_channel,
                             num_layers=layer_num,
                             conv_type=conv_type,
                             return_sequence=return_sequence,
                             is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               connect=connect,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites,
                                   output_channels=output_channels,
                                   base_channel=base_channel,
                                   num_layers=layer_num,
                                   conv_type=conv_type,
                                   return_sequence=return_sequence,
                                   is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites,
                                     output_channels=output_channels,
                                     base_channel=base_channel,
                                     num_layers=layer_num,
                                     connect=connect,
                                     conv_type=conv_type,
                                     return_sequence=return_sequence,
                                     is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites,
                                output_channels=output_channels,
                                base_channel=base_channel,
                                num_layers=layer_num,
                                conv_type=conv_type,
                                return_sequence=return_sequence,
                                is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites,
                                  output_channels=output_channels,
                                  base_channel=base_channel,
                                  num_layers=layer_num,
                                  connect=connect,
                                  conv_type=conv_type,
                                  return_sequence=return_sequence,
                                  is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               num_connects=nb_shortcut,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        else:
            raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        # optimizer = optim.Adam(net.parameters())
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net,
                                                  find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Minimum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))
        except Exception:
            logger.warning(
                'No checkpoint available, start training from scratch')

    for epoch in range(start_epoch, epochs):
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                model_type='RNN')
        n_train = len(train_set)
        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              model_type='CNN')
        n_val = len(val_set)
        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for training and {} for validation'
            .format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients (missing in the original loop,
                # which let gradients accumulate across iterations)
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                # if i == 0:
                #     in_images = images.detach().cpu().numpy()[0]
                #     in_segs = segs.detach().cpu().numpy()[0]
                #     in_pred = outputs.detach().cpu().numpy()[0]
                #     heatmap_plot(image=in_images, mask=in_segs, pred=in_pred,
                #                  name=model_name, epoch=epoch+1, is_train=True)

                running_loss += loss.detach().item()

                outputs = outputs.view(-1, outputs.shape[-4],
                                       outputs.shape[-3], outputs.shape[-2],
                                       outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2],
                                 segs.shape[-1])
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=None)
                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net,
                                                             val_set,
                                                             criterion,
                                                             device,
                                                             batch_size,
                                                             ignore_idx=None,
                                                             name=model_name,
                                                             epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save best trained model
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'
            .format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                writer.add_histogram(name + '_grad',
                                     layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data',
                                     layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    logger.info('finish training!')

    return
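# ---------------------------------------------------------------------------
# load_ckp() and save_ckp() are defined elsewhere in the repository. Given
# the `state` dict saved above (epoch / model_state_dict /
# optimizer_state_dict / scheduler_state_dict / loss / min_loss), a plausible
# minimal sketch of the resume helper is shown below; the repository's real
# load_ckp may differ in details:
# ---------------------------------------------------------------------------
def _example_load_ckp(ckp_path, net, optimizer, scheduler):
    """Restore model/optimizer/scheduler state from a checkpoint saved with
    the `state` dict layout used in train() above."""
    import torch
    checkpoint = torch.load(ckp_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # return order matches the unpacking in train():
    # net, optimizer, scheduler, start_epoch, min_loss, start_loss
    return (net, optimizer, scheduler, checkpoint['epoch'],
            checkpoint['min_loss'], checkpoint['loss'])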
    ax2.xaxis.label.set_visible(False)
    formatter = ticker.FormatStrFormatter('$%1.2f')
    ax2.xaxis.set_major_formatter(formatter)
    # ax2.set_xticklabels(['3D U-Net', '3D Res U-Net', '3D DRes U-Net', '3D DC U-Net'],
    #                     fontsize=15, rotation=30)
    ax2.set_xticklabels(
        ['U-Net backbone', 'Res U-Net backbone', 'DC U-Net backbone'],
        fontsize=15,
        rotation=15)
    ax2.set_ylabel('Distance/voxel', fontsize=15)
    ax2.set_title('Glioma', fontsize=20)

    # plt.suptitle('{} of CNN models'.format(dist.split('.')[0][:2]), fontsize=25)
    plt.savefig('4Dcomp{}_boxplot.png'.format(dist.split('.')[0][:2]))


if __name__ == '__main__':
    labels = [0, 1, 2, 3, 4]

    data_class = data_split()
    train, val, test = data_construction(data_class)
    test_dict = time_parser(test)
    patient_id = [key for key in test_dict.keys()]

    model_name_1 = 'BiDirectCenterNormalLSTM1layer-dcunet-p64-4x3'
    model_name_2 = 'direct-UNet-p64-newdata-oriinput'

    # plot(model_name_1, model_name_2)
    avg_transition_matrix(test_dict, model_name_1, model_name_2)
    # transition_matrix(test_dict, PATIENT, model_name_1, model_name_2)
    # longitudinal_HD(test_dict, PATIENT, model_name_1, model_name_2)
    # transition_heatmap(test_dict, PATIENT, model_name_1, model_name_2)

    # dice1 = 'dice_UNet-p64-b4-newdata-oriinput'
    # dice1 = 'dice_ResUNet-p64-b4-newdata-oriinput'
    # dice3 = 'dice_DResUNet-p64-b4-newdata-oriinput'
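# ---------------------------------------------------------------------------
# box_plot() (called at the end of both predict functions) is defined
# elsewhere. A minimal matplotlib sketch of a compatible implementation,
# given the metric dicts built above ({'csf': [...], 'gm': [...],
# 'wm': [...], 'tm': [...]}); purely illustrative, not the project's helper:
# ---------------------------------------------------------------------------
def _example_box_plot(scores, out_dir, metric='Dice'):
    """One box per tissue class, saved as <metric>_boxplot.png in out_dir."""
    import os
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    keys = list(scores.keys())  # e.g. ['csf', 'gm', 'wm', 'tm']
    ax.boxplot([scores[k] for k in keys], labels=keys)
    ax.set_ylabel(metric)
    fig.savefig(os.path.join(out_dir, '{}_boxplot.png'.format(metric)))
    plt.close(fig)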
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    lr = float(config['PARAMETERS']['lr'])  # int() would truncate a fractional learning rate
    optimizer = config['PARAMETERS']['optimizer']

    softmax = True
    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)
    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # groups = 4
    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel,
                         pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel,
                      pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel,
                       pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels,
                                  base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel,
                             softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels,
                                        base_channel, softmax)
    else:
        # an unknown model_type would otherwise leave `net` undefined
        raise NotImplementedError()

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net,
            input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size,
                             crop_size).to(device)
    writer.add_graph(net, (dummy_input, ))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        criterion = WeightedCrossEntropyLoss(labels=labels)
        if not softmax:
            criterion = nn.CrossEntropyLoss().cuda()
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              momentum=0.9,
                              lr=lr,
                              weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # net, optimizer = amp.initialize(net, optimizer, opt_level='O1')
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        logger.info('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net)
    else:
        logger.info('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Minimum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))
            # min_loss = float('Inf')
        except Exception:
            logger.warning(
                'No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # every epoch generate a new set of images
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                modalities=modalities,
                                model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              modalities=modalities,
                              model_type='CNN')
        n_val = len(val_set)

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for training and {} for validation'
            .format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)
                if model_type == 'SkipDenseSeg' and not softmax:
                    segs = segs.long()

                # combine the batch and time step dimensions
                batch, time, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                # with amp.scale_loss(loss, optimizer) as scaled_loss:
                #     scaled_loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=ignore_idx)
                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                del images, segs

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(
                        net,
                        val_set,
                        criterion,
                        device,
                        batch_size,
                        model_type=model_type,
                        softmax=softmax,
                        ignore_idx=ignore_idx)

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save best trained model
        if model_type == 'SkipDenseSeg':
            scheduler.step()
        else:
            scheduler.step(val_loss)

        # debug
        for param_group in optimizer.param_groups:
            logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

        if min_loss > running_loss / nb_batches + 1e-2:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        # save the check point
        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation accuracy: {}'.format(
                val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)
        loss_plot(train_info_file, name=model_name)

        for name, layer in net.named_parameters():
            writer.add_histogram(name + '_grad',
                                 layer.grad.cpu().data.numpy(), epoch)
            writer.add_histogram(name + '_data',
                                 layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
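# ---------------------------------------------------------------------------
# The CNN training loop above folds the longitudinal time axis into the
# batch axis before the forward pass (images.view(-1, channel, z, y, x)),
# so a plain 3D network can consume 4D longitudinal batches. A tiny
# shape-only demonstration of that reshape, with made-up dimensions:
# ---------------------------------------------------------------------------
def _example_fold_time_into_batch():
    import torch
    batch, time, channel, z, y, x = 2, 3, 4, 8, 8, 8
    images = torch.rand(batch, time, channel, z, y, x)
    segs = torch.randint(0, 5, (batch, time, z, y, x))
    flat_images = images.view(-1, channel, z, y, x)  # -> (6, 4, 8, 8, 8)
    flat_segs = segs.view(-1, z, y, x)               # -> (6, 8, 8, 8)
    return flat_images.shape, flat_segs.shape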