def evaluate_saved_model(model_config, split='validation', model_path=None,
                         data_path=None, save_directory=None, save_nii=False,
                         save_npz=False):
    # Load options
    json_opts = json_file_to_pyobj(model_config)
    train_opts = json_opts.training
    model_opts = json_opts.model
    data_path_opts = json_opts.data_path

    if model_path is not None:
        model_opts = json_opts.model._replace(
            path_pre_trained_model=model_path)
    model_opts = model_opts._replace(gpu_ids=[])

    # Setup the NN Model
    model = get_model(model_opts)
    if save_directory is None:
        save_directory = os.path.join(os.path.dirname(model_config),
                                      split + '_evaluation')
    mkdir(save_directory)

    # Setup Dataset and Augmentation
    ds_class = get_dataset(train_opts.arch_type)
    if data_path is None:
        data_path = get_dataset_path(train_opts.arch_type, data_path_opts)
    dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                                   opts=json_opts.augmentation)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != getattr(json_opts.augmentation,
                                        train_opts.arch_type).scale_size[-1]:
        raise Exception(
            'Number of data channels must match number of model channels, '
            'and patch and scale size dimensions')

    # Setup Data Loader
    split_opts = json_opts.data_split
    dataset = ds_class(data_path,
                       split=split,
                       transform=dataset_transform['valid'],
                       preload_data=train_opts.preloadData,
                       train_size=split_opts.train_size,
                       test_size=split_opts.test_size,
                       valid_size=split_opts.validation_size,
                       split_seed=split_opts.seed,
                       channels=channels)
    data_loader = DataLoader(dataset=dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=False)

    # Visualisation Parameters
    # visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)

    # Setup stats logger
    stat_logger = StatLogger()

    if save_npz:
        all_predicted = []

    # test
    for iteration, data in tqdm(enumerate(data_loader, 1)):
        model.set_input(data[0], data[1])
        model.test()

        input_arr = np.squeeze(data[0].cpu().numpy()).astype(np.float32)
        prior_arr = np.squeeze(data[0].cpu().numpy())[5].astype(np.int16)
        prior_arr[prior_arr > 0] = 1
        label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16)
        ids = dataset.get_ids(data[2])
        output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(
            np.int16)

        # If there is a label image - compute statistics
        dice_vals = dice_score(label_arr, output_arr, n_class=2)
        single_class_dice = single_class_dice_score(label_arr, output_arr)
        md, hd = distance_metric(label_arr, output_arr, dx=2.00, k=1)
        precision, recall = precision_and_recall(label_arr, output_arr,
                                                 n_class=2)
        sp = specificity(label_arr, output_arr)
        jaccard = jaccard_score(label_arr.flatten(), output_arr.flatten())

        # Compute stats for the prior that is used
        prior_dice = single_class_dice_score(label_arr, prior_arr)
        prior_precision, prior_recall = precision_and_recall(label_arr,
                                                             prior_arr,
                                                             n_class=2)

        stat_logger.update(split=split,
                           input_dict={
                               'img_name': ids[0],
                               'dice_bg': dice_vals[0],
                               'dice_les': dice_vals[1],
                               'dice2_les': single_class_dice,
                               'prec_les': precision[1],
                               'reca_les': recall[1],
                               'specificity': sp,
                               'md_les': md,
                               'hd_les': hd,
                               'jaccard': jaccard,
                               'dice_prior': prior_dice,
                               'prec_prior': prior_precision[1],
                               'reca_prior': prior_recall[1]
                           })

        if save_nii:
            # Write a nifti image
            import SimpleITK as sitk

            input_img = sitk.GetImageFromArray(
                np.transpose(input_arr[0], (2, 1, 0)))
            input_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
            cbf_img = sitk.GetImageFromArray(
                np.transpose(input_arr[1], (2, 1, 0)))
            cbf_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
            prior_img = sitk.GetImageFromArray(
                np.transpose(input_arr[5], (2, 1, 0)))
            prior_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
            label_img = sitk.GetImageFromArray(
                np.transpose(label_arr, (2, 1, 0)))
            label_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
            predi_img = sitk.GetImageFromArray(
                np.transpose(output_arr, (2, 1, 0)))
            predi_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])

            sitk.WriteImage(
                input_img,
                os.path.join(save_directory,
                             '{}_img.nii.gz'.format(iteration)))
            sitk.WriteImage(
                cbf_img,
                os.path.join(save_directory,
                             '{}_cbf.nii.gz'.format(iteration)))
            sitk.WriteImage(
                prior_img,
                os.path.join(save_directory,
                             '{}_prior.nii.gz'.format(iteration)))
            sitk.WriteImage(
                label_img,
                os.path.join(save_directory,
                             '{}_lbl.nii.gz'.format(iteration)))
            sitk.WriteImage(
                predi_img,
                os.path.join(save_directory,
                             '{}_pred.nii.gz'.format(iteration)))

        if save_npz:
            all_predicted.append(output_arr)

    stat_logger.statlogger2csv(split=split,
                               out_csv_name=os.path.join(
                                   save_directory, split + '_stats.csv'))
    for key, (mean_val, std_val) in stat_logger.get_errors(split=split).items():
        print('-', key, ': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val), '-')

    if save_npz:
        np.savez_compressed(os.path.join(save_directory, 'predictions.npz'),
                            predicted=np.array(all_predicted))
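

# Hypothetical usage sketch (not part of the original script): evaluate a
# trained model on the held-out test split and export NIfTI volumes and an
# .npz archive of predictions. The config and checkpoint paths below are
# illustrative placeholders, not files shipped with this repository.
#
#   evaluate_saved_model('experiments/unet_prior/config.json',
#                        split='test',
#                        model_path='experiments/unet_prior/best_net.pth',
#                        save_nii=True,
#                        save_npz=True)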
def train(arguments):
    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(
        arch_type,
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc,
        verbose=json_opts.training.verbose)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != getattr(json_opts.augmentation,
                                        arch_type).scale_size[-1]:
        raise Exception(
            'Number of data channels must match number of model channels, '
            'and patch and scale size dimensions')

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    split_opts = json_opts.data_split
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData,
                             train_size=split_opts.train_size,
                             test_size=split_opts.test_size,
                             valid_size=split_opts.validation_size,
                             split_seed=split_opts.seed,
                             channels=channels)
    valid_dataset = ds_class(ds_path,
                             split='validation',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData,
                             train_size=split_opts.train_size,
                             test_size=split_opts.test_size,
                             valid_size=split_opts.validation_size,
                             split_seed=split_opts.seed,
                             channels=channels)
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData,
                            train_size=split_opts.train_size,
                            test_size=split_opts.test_size,
                            valid_size=split_opts.validation_size,
                            split_seed=split_opts.seed,
                            channels=channels)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=16,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=16,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=16,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)

    # Setup Early Stopping
    early_stopper = EarlyStopper(json_opts.training.early_stopping,
                                 verbose=json_opts.training.verbose)

    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))
        train_volumes = []
        validation_volumes = []

        # Training Iterations
        for epoch_iter, (images, labels, indices) in tqdm(
                enumerate(train_loader, 1), total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            # model.optimize_parameters_accumulate_grd(epoch_iter)

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

            ids = train_dataset.get_ids(indices)
            volumes = model.get_current_volumes()
            visualizer.display_current_volumes(volumes, ids, 'train', epoch)
            train_volumes.append(volumes)

        # Validation and Testing Iterations
        for loader, split, dataset in zip([valid_loader, test_loader],
                                          ['validation', 'test'],
                                          [valid_dataset, test_dataset]):
            for epoch_iter, (images, labels, indices) in tqdm(
                    enumerate(loader, 1), total=len(loader)):
                ids = dataset.get_ids(indices)

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                if split == 'validation':  # do not look at testing
                    # Visualise predictions
                    volumes = model.get_current_volumes()
                    visualizer.display_current_volumes(volumes, ids, split,
                                                       epoch)
                    validation_volumes.append(volumes)

                    # Track validation loss values
                    early_stopper.update({**errors, **stats})

        # Update the plots
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        visualizer.save_plots(epoch, save_frequency=5)
        error_logger.reset()

        # Save the model parameters
        if early_stopper.is_improving is not False:
            model.save(json_opts.model.model_type, epoch)
            save_config(json_opts, json_filename, model, epoch)

        # Update the model learning rate
        model.update_learning_rate(
            metric=early_stopper.get_current_validation_loss())

        if early_stopper.interrogate(epoch):
            break
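

# Minimal sketch of a command-line entry point, assuming this module is run
# directly. The flag names are inferred from the attributes train() reads
# (arguments.config and arguments.debug); the original project's CLI may
# differ.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='CNN segmentation training')
    parser.add_argument('-c', '--config', required=True,
                        help='path to the training configuration JSON file')
    parser.add_argument('-d', '--debug', action='store_true',
                        help='print parameter count and fp/bp timing, then exit')
    args = parser.parse_args()

    train(args)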