コード例 #1
0
    def __init__(self,
                 model,
                 exp_name,
                 device,
                 num_class,
                 optim=torch.optim.SGD,
                 optim_args={},
                 loss_func=losses.DiceLoss(),
                 model_name='OneShotSegmentor',
                 labels=None,
                 num_epochs=10,
                 log_nth=5,
                 lr_scheduler_step_size=5,
                 lr_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 exp_dir='experiments',
                 log_dir='logs'):
        """Prepare optimizers, schedulers, directories and logging for training.

        The model must expose ``conditioner`` and ``segmentor`` sub-modules;
        each one gets its own optimizer and its own StepLR scheduler.

        Args:
            model: network exposing ``conditioner`` and ``segmentor``.
            exp_name: experiment name; used as sub-directory of ``exp_dir``
                and forwarded to the log writer.
            device: CUDA device the loss function is moved to when CUDA is
                available.
            num_class: number of classes, forwarded to the log writer.
            optim: optimizer class, instantiated once per sub-module.
            optim_args: extra keyword arguments for both optimizers.
            loss_func: loss module used for training.
            model_name: stored on the instance.
            labels: stored on the instance and forwarded to the log writer.
            num_epochs: stored on the instance.
            log_nth: stored on the instance.
            lr_scheduler_step_size: accepted but not used by this
                constructor (both schedulers use a fixed step size of 10).
            lr_scheduler_gamma: accepted but not used by this constructor
                (gammas are fixed to 0.1 and 0.001).
            use_last_checkpoint: when true, ``load_checkpoint`` is called at
                the end of construction.
            exp_dir: root directory for experiment output.
            log_dir: log directory forwarded to the log writer.
        """
        self.device = device
        self.model = model

        self.model_name = model_name
        self.labels = labels
        self.num_epochs = num_epochs
        self.loss_func = loss_func.cuda(device) if torch.cuda.is_available() else loss_func

        # Both parameter groups share the same fixed hyper-parameters.
        group_settings = {'lr': 1e-3, 'momentum': 0.99, 'weight_decay': 0.0001}
        conditioner_groups = [{'params': model.conditioner.parameters(), **group_settings}]
        self.optim_c = optim(conditioner_groups, **optim_args)
        segmentor_groups = [{'params': model.segmentor.parameters(), **group_settings}]
        self.optim_s = optim(segmentor_groups, **optim_args)

        self.scheduler_s = lr_scheduler.StepLR(self.optim_s, step_size=10, gamma=0.1)
        self.scheduler_c = lr_scheduler.StepLR(self.optim_c, step_size=10, gamma=0.001)

        # Experiment directory first, then its checkpoint sub-directory.
        experiment_root = os.path.join(exp_dir, exp_name)
        for folder in (experiment_root, os.path.join(experiment_root, CHECKPOINT_DIR)):
            common_utils.create_if_not(folder)
        self.exp_dir_path = experiment_root

        self.log_nth = log_nth
        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)

        self.use_last_checkpoint = use_last_checkpoint
        self.start_epoch, self.start_iteration = 1, 1
        self.best_ds_mean, self.best_ds_mean_epoch = 0, 0

        if not use_last_checkpoint:
            return
        self.load_checkpoint()
コード例 #2
0
ファイル: solver_sgd.py プロジェクト: ai-med/AbdomenNet
    def __init__(self,
                 model,
                 exp_name,
                 device,
                 num_class,
                 optim=torch.optim.SGD,
                 optim_args={},
                 loss_func=additional_losses.CombinedLoss(),
                 model_name='quicknat',
                 labels=None,
                 num_epochs=10,
                 log_nth=5,
                 lr_scheduler_step_size=5,
                 lr_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 exp_dir='experiments',
                 log_dir='logs',
                 arch_file_path=None):
        """Prepare optimizer, scheduler, directories and logging for training.

        Args:
            model: network whose ``parameters()`` are optimized.
            exp_name: experiment name; used as sub-directory of ``exp_dir``
                and forwarded to the log writer.
            device: CUDA device the loss function is moved to when CUDA is
                available.
            num_class: number of classes, forwarded to the log writer.
            optim: optimizer class.
            optim_args: keyword arguments for the optimizer.
            loss_func: loss module used for training.
            model_name: stored on the instance.
            labels: stored on the instance and forwarded to the log writer.
            num_epochs: stored on the instance.
            log_nth: stored on the instance.
            lr_scheduler_step_size: accepted but not used by this
                constructor (cosine annealing is used instead of StepLR).
            lr_scheduler_gamma: accepted but not used by this constructor.
            use_last_checkpoint: when true, ``load_checkpoint`` is called at
                the end of construction.
            exp_dir: root directory for experiment output.
            log_dir: log directory forwarded to the log writer.
            arch_file_path: forwarded unchanged to
                ``save_architectural_files``.
        """
        self.device = device
        self.model = model
        self.model_name = model_name
        self.labels = labels
        self.num_epochs = num_epochs

        on_gpu = torch.cuda.is_available()
        self.loss_func = loss_func.cuda(device) if on_gpu else loss_func

        self.optim = optim(model.parameters(), **optim_args)
        # Cosine annealing with a fixed period of 100 scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, T_max=100)

        # Experiment directory first, then its checkpoint sub-directory.
        experiment_root = os.path.join(exp_dir, exp_name)
        checkpoint_root = os.path.join(experiment_root, CHECKPOINT_DIR)
        common_utils.create_if_not(experiment_root)
        common_utils.create_if_not(checkpoint_root)
        self.exp_dir_path = experiment_root

        # Needs self.exp_dir_path, so it must run after the assignment above.
        self.save_architectural_files(arch_file_path)

        self.log_nth = log_nth
        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)

        self.use_last_checkpoint = use_last_checkpoint

        self.start_epoch, self.start_iteration = 1, 1
        self.best_ds_mean, self.best_ds_mean_epoch = 0, 0

        if use_last_checkpoint:
            self.load_checkpoint()

        print(self.best_ds_mean, self.best_ds_mean_epoch, self.start_epoch)
コード例 #3
0
def evaluate(coronal_model_path, volumes_txt_file, data_dir, device, prediction_path, batch_size, orientation,
             label_names, dir_struct, need_unc=False, mc_samples=0):
    """Segment every listed volume with one model and write predictions and CSV tables.

    For each volume the prediction is saved as ``<volume id>.nii`` (identity
    affine) in ``prediction_path`` and a per-volume entry is collected for
    ``volume_estimates.csv``. When uncertainty estimation is requested,
    ``cvs_uncertainty.csv`` and ``iou_uncertainty.csv`` are written as well.

    Args:
        coronal_model_path: path of the serialized model passed to ``torch.load``.
        volumes_txt_file: text file with one volume id per line.
        data_dir: data directory forwarded to ``du.load_file_paths_eval``.
        device: CUDA device the model is moved to when CUDA is available.
        prediction_path: output directory (created if missing).
        batch_size: forwarded to the segmentation helper.
        orientation: forwarded to the segmentation helper.
        label_names: label names used for the volume/uncertainty tables.
        dir_struct: forwarded to ``du.load_file_paths_eval``.
        need_unc: enables uncertainty estimation; both the boolean ``True``
            and the string ``"True"`` are accepted.
        mc_samples: forwarded to ``_segment_vol_unc`` (presumably the number
            of Monte-Carlo samples -- confirm against that helper).
    """
    print("**Starting evaluation**")
    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    # The default is the bool False while configs pass the string "True";
    # accept both so that need_unc=True is not silently ignored.
    estimate_uncertainty = need_unc is True or need_unc == "True"

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap storages to the CPU so GPU-saved models still load.
    model = torch.load(coronal_model_path, map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)

    model.eval()

    common_utils.create_if_not(prediction_path)
    print("Evaluating now...")
    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    with torch.no_grad():
        volume_dict_list = []
        cvs_dict_list = []
        iou_dict_list = []
        for vol_idx, file_path in enumerate(file_paths):
            try:
                if estimate_uncertainty:
                    _, volume_prediction, mc_pred_list, header = _segment_vol_unc(file_path, model, orientation,
                                                                                  batch_size, mc_samples,
                                                                                  cuda_available, device)
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names,
                                                                       volumes_to_use[vol_idx])
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    _, volume_prediction, header = _segment_vol(file_path, model, orientation, batch_size,
                                                                cuda_available,
                                                                device)

                nifti_img = nib.Nifti1Image(volume_prediction, np.eye(4), header=header)
                print("Processed: " + volumes_to_use[vol_idx] + " " + str(vol_idx + 1) + " out of " + str(
                    len(file_paths)))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_to_use[vol_idx] + str('.nii')))
                per_volume_dict = compute_volume(volume_prediction, label_names, volumes_to_use[vol_idx])
                volume_dict_list.append(per_volume_dict)
            except FileNotFoundError:
                # Best effort: a missing volume is reported and skipped.
                print("Error in reading the file ...")

        _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)

        if estimate_uncertainty:
            _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
            _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    print("DONE")
コード例 #4
0
    def __init__(self,
                 model,
                 exp_name,
                 device,
                 num_class,
                 optim=torch.optim.SGD,
                 optim_args={},
                 loss_func=losses.CombinedLoss(),
                 model_name='segmentor',
                 labels=None,
                 num_epochs=10,
                 log_nth=5,
                 lr_scheduler_step_size=5,
                 lr_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 exp_dir='experiments',
                 log_dir='logs'):
        """Prepare optimizer, StepLR scheduler, directories and logging for training.

        Args:
            model: network whose ``parameters()`` are optimized.
            exp_name: experiment name; used as sub-directory of ``exp_dir``
                and forwarded to the log writer.
            device: CUDA device the loss function is moved to when CUDA is
                available.
            num_class: number of classes, forwarded to the log writer.
            optim: optimizer class.
            optim_args: keyword arguments for the optimizer.
            loss_func: loss module used for training.
            model_name: stored on the instance.
            labels: stored on the instance and forwarded to the log writer.
            num_epochs: stored on the instance.
            log_nth: stored on the instance.
            lr_scheduler_step_size: ``step_size`` of the StepLR scheduler.
            lr_scheduler_gamma: ``gamma`` of the StepLR scheduler.
            use_last_checkpoint: when true, ``load_checkpoint`` is called at
                the end of construction.
            exp_dir: root directory for experiment output.
            log_dir: log directory forwarded to the log writer.
        """
        self.device = device
        self.model = model
        self.model_name = model_name
        self.labels = labels
        self.num_epochs = num_epochs

        if not torch.cuda.is_available():
            self.loss_func = loss_func
        else:
            self.loss_func = loss_func.cuda(device)

        self.optim = optim(model.parameters(), **optim_args)
        scheduler_options = {'step_size': lr_scheduler_step_size, 'gamma': lr_scheduler_gamma}
        self.scheduler = lr_scheduler.StepLR(self.optim, **scheduler_options)

        # Experiment directory first, then its checkpoint sub-directory.
        self.exp_dir_path = os.path.join(exp_dir, exp_name)
        common_utils.create_if_not(self.exp_dir_path)
        common_utils.create_if_not(os.path.join(self.exp_dir_path, CHECKPOINT_DIR))

        self.log_nth = log_nth
        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)

        self.use_last_checkpoint = use_last_checkpoint

        # Initial values; load_checkpoint() is called afterwards when requested.
        self.start_epoch = self.start_iteration = 1
        self.best_ds_mean = self.best_ds_mean_epoch = 0

        if use_last_checkpoint:
            self.load_checkpoint()
コード例 #5
0
ファイル: solver_sgd.py プロジェクト: ai-med/AbdomenNet
 def save_architectural_files(self, arch_file_paths):
     """Copy the model source, its sibling files and the settings file into the experiment directory.

     Args:
         arch_file_paths: ``None`` (nothing is copied, a notice is printed)
             or a pair ``(arch_file_path, setting_path)``. The files
             ``run.py``, ``solver.py``, ``utils/evaluator.py``,
             ``nn_common_modules/losses.py`` and
             ``nn_common_modules/modules.py`` are looked up relative to the
             directory containing ``arch_file_path``.

     Raises:
         OSError: propagated from ``shutil.copy`` when a source file is
             missing or the destination is not writable.
     """
     if arch_file_paths is None:
         print('No Architectural file!!!')
         return

     arch_file_path, setting_path = arch_file_paths
     destination = os.path.join(self.exp_dir_path, ARCHITECTURE_DIR)
     common_utils.create_if_not(destination)

     # os.path.dirname instead of splitting on '/': for a bare file name the
     # old code produced '' and then looked for '/run.py' in the file system
     # root; os.path.join('', 'run.py') correctly stays in the current
     # directory and also honours the platform separator.
     arch_base = os.path.dirname(arch_file_path)
     model_target = os.path.join(destination, 'model.py')
     print(arch_file_path, arch_base, setting_path, model_target)
     shutil.copy(arch_file_path, model_target)

     # (path relative to arch_base, flattened file name inside destination)
     sibling_files = (
         (('run.py',), 'run.py'),
         (('solver.py',), 'solver.py'),
         (('utils', 'evaluator.py'), 'utils-evaluator.py'),
         (('nn_common_modules', 'losses.py'), 'nn_common_modules-losses.py'),
         (('nn_common_modules', 'modules.py'), 'nn_common_modules-modules.py'),
     )
     for relative_parts, target_name in sibling_files:
         shutil.copy(os.path.join(arch_base, *relative_parts), os.path.join(destination, target_name))

     shutil.copy(setting_path, os.path.join(destination, 'settings.ini'))
コード例 #6
0
def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file, remap_config, orientation,
                        prediction_path, data_id, device=0, logWriter=None, mode='eval'):
    """Segment every listed volume, save the predictions and compute per-class Dice scores.

    Each prediction is saved as ``<volume id>.mgz`` (identity affine) in
    ``prediction_path``.

    Args:
        model_path: path of the serialized model passed to ``torch.load``.
        num_classes: number of classes; forwarded to ``dice_score_perclass``
            and used to split the scores per class.
        data_dir, label_dir, data_id: forwarded to ``du.load_file_paths``.
        volumes_txt_file: text file with one volume id per line.
        remap_config, orientation: forwarded to ``du.load_and_preprocess``.
        prediction_path: output directory (created if missing).
        device: CUDA device used when CUDA is available.
        logWriter: optional writer; when given, per-volume scores and a box
            plot are sent to it.
        mode: forwarded to ``dice_score_perclass``.

    Returns:
        Tuple ``(avg_dice_score, class_dist)``: the mean over all collected
        scores and a list with one score array per class.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")

    batch_size = 20

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap storages to the CPU so GPU-saved models still load.
    model = torch.load(model_path, map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)

    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    print("Evaluating now...")
    file_paths = du.load_file_paths(data_dir, label_dir, data_id, volumes_txt_file)
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume, labelmap, _, _, header = du.load_and_preprocess(file_path,
                                                                    orientation=orientation,
                                                                    remap_config=remap_config)

            # Add a singleton axis at position 1 when the volume is not 4-D yet.
            volume = volume if len(volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            batch_predictions = []
            for start in range(0, len(volume), batch_size):
                batch_x = volume[start:start + batch_size]
                if cuda_available:
                    batch_x = batch_x.cuda(device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                batch_predictions.append(batch_output)

            volume_prediction = torch.cat(batch_predictions)
            # Put the labels on whatever device the prediction lives on; the
            # previous unconditional labelmap.cuda(device) crashed on hosts
            # without CUDA.
            volume_dice_score = dice_score_perclass(volume_prediction, labelmap.to(volume_prediction.device),
                                                    num_classes, mode=mode)

            volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
            nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4), header=header)
            nib.save(nifti_img, os.path.join(prediction_path, volumes_to_use[vol_idx] + str('.mgz')))
            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], vol_idx)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            print(volume_dice_score, np.mean(volume_dice_score))
        dice_score_arr = np.asarray(volume_dice_score_list)
        avg_dice_score = np.mean(dice_score_arr)
        print("Mean of dice score : " + str(avg_dice_score))
        class_dist = [dice_score_arr[:, c] for c in range(num_classes)]

        if logWriter:
            logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')
    print("DONE")

    return avg_dice_score, class_dist
コード例 #7
0
def evaluate2view(coronal_model_path, axial_model_path, volumes_txt_file, data_dir, device, prediction_path, batch_size,
                  label_names, dir_struct, need_unc=False, mc_samples=0):
    """Segment every listed volume with a coronal and an axial model and fuse both outputs.

    The outputs of both models are summed and the arg-max over axis 1 is
    taken as the final label map. Each result is saved as
    ``<volume id>.nii.gz`` (identity affine) in ``prediction_path``;
    ``volume_estimates.csv`` and, if requested, the two uncertainty tables
    are written at the end.

    Args:
        coronal_model_path: serialized model run with orientation ``"COR"``.
        axial_model_path: serialized model run with orientation ``"AXI"``.
        volumes_txt_file: text file with one volume id per line.
        data_dir: data directory forwarded to ``du.load_file_paths_eval``.
        device: CUDA device the models are moved to when CUDA is available.
        prediction_path: output directory (created if missing).
        batch_size: forwarded to the segmentation helpers.
        label_names: label names used for the volume/uncertainty tables.
        dir_struct: forwarded to ``du.load_file_paths_eval``.
        need_unc: enables uncertainty estimation; both the boolean ``True``
            and the string ``"True"`` are accepted.
        mc_samples: forwarded to ``_segment_vol_unc`` (presumably the number
            of Monte-Carlo samples -- confirm against that helper).
    """
    import logging  # function-scope import, as in the original error handler

    print("**Starting evaluation**")
    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    # The default is the bool False while configs pass the string "True";
    # accept both so that need_unc=True is not silently ignored.
    estimate_uncertainty = need_unc is True or need_unc == "True"

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap storages to the CPU so GPU-saved models still load.
    map_location = None if cuda_available else 'cpu'
    model1 = torch.load(coronal_model_path, map_location=map_location)
    model2 = torch.load(axial_model_path, map_location=map_location)

    if cuda_available:
        torch.cuda.empty_cache()
        model1.cuda(device)
        model2.cuda(device)

    model1.eval()
    model2.eval()

    common_utils.create_if_not(prediction_path)
    print("Evaluating now...")

    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    with torch.no_grad():
        volume_dict_list = []
        cvs_dict_list = []
        iou_dict_list = []
        for vol_idx, file_path in enumerate(file_paths):
            try:
                if estimate_uncertainty:
                    volume_prediction_cor, _, mc_pred_list_cor, header = _segment_vol_unc(file_path, model1, "COR",
                                                                                          batch_size, mc_samples,
                                                                                          cuda_available, device)
                    volume_prediction_axi, _, mc_pred_list_axi, header = _segment_vol_unc(file_path, model2, "AXI",
                                                                                          batch_size, mc_samples,
                                                                                          cuda_available, device)
                    # Samples of both views are pooled for the uncertainty measures.
                    mc_pred_list = mc_pred_list_cor + mc_pred_list_axi
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names,
                                                                       volumes_to_use[vol_idx])
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    volume_prediction_cor, _, header = _segment_vol(file_path, model1, "COR", batch_size,
                                                                    cuda_available,
                                                                    device)
                    volume_prediction_axi, _, header = _segment_vol(file_path, model2, "AXI", batch_size,
                                                                    cuda_available,
                                                                    device)

                # Fuse the two views: sum the outputs, then arg-max over axis 1.
                _, volume_prediction = torch.max(volume_prediction_axi + volume_prediction_cor, dim=1)
                volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
                volume_prediction = np.squeeze(volume_prediction)
                nifti_img = nib.Nifti1Image(volume_prediction, np.eye(4), header=header)
                print("Processed: " + volumes_to_use[vol_idx] + " " + str(vol_idx + 1) + " out of " + str(
                    len(file_paths)))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_to_use[vol_idx] + str('.nii.gz')))

                per_volume_dict = compute_volume(volume_prediction, label_names, volumes_to_use[vol_idx])
                volume_dict_list.append(per_volume_dict)

            except FileNotFoundError:
                print("Error in reading the file ...")
            except Exception as exp:
                # Best effort: log the failure with traceback and continue
                # with the next volume.
                logging.getLogger(__name__).exception(exp)

        _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)

        if estimate_uncertainty:
            _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
            _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    print("DONE")
コード例 #8
0
def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file, remap_config, orientation,
                        prediction_path, data_id, device=0, logWriter=None, mode='eval'):
    """Segment every listed volume, save the predictions and compute per-class Dice scores.

    Labels are mapped back with ``preprocessor.remap_labels_back``, the axes
    are transposed for the ``"COR"`` / ``"AXI"`` orientations and the result
    is written as NIfTI next to the name of the input file
    (``.nii`` replaced by ``_seg1.nii``), using the affine built from the
    header's ``srow_x/y/z`` rows.

    Args:
        model_path: path of the serialized model passed to ``torch.load``.
        num_classes: number of classes; forwarded to ``dice_score_perclass``
            and used to split the scores per class.
        data_dir, label_dir, data_id: forwarded to ``du.load_file_paths``.
        volumes_txt_file: text file with one volume id per line.
        remap_config: forwarded to ``du.load_and_preprocess`` and
            ``preprocessor.remap_labels_back``.
        orientation: forwarded to ``du.load_and_preprocess``; also selects
            the axis transposition of the saved prediction.
        prediction_path: output directory (created if missing).
        device: an int selects that CUDA device (falling back to the CPU with
            a warning when CUDA is unavailable); anything else is passed to
            ``torch.device`` (e.g. ``'cpu'``).
        logWriter: optional writer; when given, per-volume scores and a box
            plot are sent to it.
        mode: forwarded to ``dice_score_perclass``.

    Returns:
        Tuple ``(avg_dice_score, class_dist)``: the mean over all collected
        scores and a list with one score array per class.
    """
    log.info("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")

    # A batch size of 20 did not fit in memory.
    batch_size = 10

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # An int device means "run on that GPU"; fall back to the CPU if there is none.
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU. '
            'This can take much longer (> 1 hour). Cancel and '
            'investigate if this behavior is not desired.'
        )
        device = 'cpu'

    if isinstance(device, int):
        run_device = torch.device('cuda', device)
        model = torch.load(model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        run_device = torch.device(device)
        model = torch.load(model_path, map_location=run_device)

    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    log.info("Evaluating now...")
    file_paths = du.load_file_paths(data_dir, label_dir, data_id, volumes_txt_file)
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume, labelmap, _, _, header = du.load_and_preprocess(file_path,
                                                                    orientation=orientation,
                                                                    remap_config=remap_config)

            # Add a singleton axis at position 1 when the volume is not 4-D yet.
            volume = volume if len(volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            batch_predictions = []
            for start in range(0, len(volume), batch_size):
                # .to() is a no-op on the CPU and also covers string CUDA
                # devices, which the old int-only check left on the CPU.
                batch_x = volume[start:start + batch_size].to(run_device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                batch_predictions.append(batch_output)

            volume_prediction = torch.cat(batch_predictions)
            # Put the labels on the prediction's device; the previous
            # labelmap.cuda(device) failed for device='cpu'.
            volume_dice_score = dice_score_perclass(volume_prediction, labelmap.to(volume_prediction.device),
                                                    num_classes, mode=mode)

            volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')

            # Rebuild the affine from the header's srow_* rows.
            affine = np.array([
                header['srow_x'],
                header['srow_y'],
                header['srow_z'],
                [0, 0, 0, 1]
            ])

            volume_prediction = np.squeeze(volume_prediction)
            volume_prediction = preprocessor.remap_labels_back(volume_prediction, remap_config)

            if orientation == "COR":
                volume_prediction = volume_prediction.transpose((1, 2, 0))
            elif orientation == "AXI":
                volume_prediction = volume_prediction.transpose((2, 0, 1))

            # Save with the affine taken from the input header.
            nifti_img = nib.Nifti1Image(volume_prediction, affine, header=header)
            outputfilename = os.path.join(prediction_path,
                                          os.path.basename(file_path[0]).replace(".nii", "_seg1.nii"))
            nib.save(nifti_img, outputfilename)

            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], vol_idx)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            # Logger methods take a format string plus arguments; passing the
            # array as the message made the logging call itself fail.
            log.info("%s %s", volume_dice_score, np.mean(volume_dice_score))
        dice_score_arr = np.asarray(volume_dice_score_list)
        avg_dice_score = np.mean(dice_score_arr)
        log.info("Mean of dice score : " + str(avg_dice_score))
        class_dist = [dice_score_arr[:, c] for c in range(num_classes)]

        if logWriter:
            logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')
    log.info("DONE")

    return avg_dice_score, class_dist
コード例 #9
0
def evaluate2view(coronal_model_path, axial_model_path, volumes_txt_file, data_dir, device, prediction_path, batch_size,
                  label_names, dir_struct, need_unc=False, mc_samples=0, exit_on_error=False):
    """Segment every listed volume with a coronal and an axial model and fuse both outputs.

    The outputs of both models are summed and the arg-max over axis 1 is
    taken as the final label map. Each result is saved as
    ``<volume id>.nii.gz`` in ``prediction_path`` with the affine built from
    the header's ``srow_x/y/z`` rows; ``volume_estimates.csv`` and, if
    requested, the two uncertainty tables are written at the end.

    Args:
        coronal_model_path: serialized model run with orientation ``"COR"``.
        axial_model_path: serialized model run with orientation ``"AXI"``.
        volumes_txt_file: text file with one volume id per line.
        data_dir: data directory forwarded to ``du.load_file_paths_eval``.
        device: an int selects that CUDA device (falling back to the CPU with
            a warning when CUDA is unavailable); anything else is passed to
            ``torch.device`` (e.g. ``'cpu'``).
        prediction_path: output directory (created if missing).
        batch_size: forwarded to the segmentation helpers.
        label_names: label names used for the volume/uncertainty tables.
        dir_struct: forwarded to ``du.load_file_paths_eval``.
        need_unc: enables uncertainty estimation; both the boolean ``True``
            and the string ``"True"`` are accepted.
        mc_samples: forwarded to ``_segment_vol_unc`` (presumably the number
            of Monte-Carlo samples -- confirm against that helper).
        exit_on_error: when true, the first per-volume exception is
            re-raised after logging instead of skipping the volume.

    Raises:
        Exception: any per-volume error, only when ``exit_on_error`` is true.
    """
    log.info("**Starting evaluation**")
    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    # The default is the bool False while configs pass the string "True";
    # accept both so that need_unc=True is not silently ignored.
    estimate_uncertainty = need_unc is True or need_unc == "True"

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU. '
            'This can take much longer (> 1 hour). Cancel and '
            'investigate if this behavior is not desired.'
        )
        # Actually switch to the CPU: torch.device(<int>) denotes a CUDA
        # device, so keeping the int made the map_location load below fail.
        device = 'cpu'

    if isinstance(device, int):
        model1 = torch.load(coronal_model_path)
        model2 = torch.load(axial_model_path)

        torch.cuda.empty_cache()
        model1.cuda(device)
        model2.cuda(device)
    else:
        map_location = torch.device(device)
        model1 = torch.load(coronal_model_path, map_location=map_location)
        model2 = torch.load(axial_model_path, map_location=map_location)

    model1.eval()
    model2.eval()

    common_utils.create_if_not(prediction_path)
    log.info("Evaluating now...")

    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    with torch.no_grad():
        volume_dict_list = []
        cvs_dict_list = []
        iou_dict_list = []
        for vol_idx, file_path in enumerate(file_paths):
            try:
                if estimate_uncertainty:
                    volume_prediction_cor, _, mc_pred_list_cor, header = _segment_vol_unc(file_path, model1, "COR",
                                                                                          batch_size, mc_samples,
                                                                                          cuda_available, device)
                    volume_prediction_axi, _, mc_pred_list_axi, header = _segment_vol_unc(file_path, model2, "AXI",
                                                                                          batch_size, mc_samples,
                                                                                          cuda_available, device)
                    # Samples of both views are pooled for the uncertainty measures.
                    mc_pred_list = mc_pred_list_cor + mc_pred_list_axi
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names,
                                                                       volumes_to_use[vol_idx])
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    volume_prediction_cor, _, header = _segment_vol(file_path, model1, "COR", batch_size,
                                                                    cuda_available,
                                                                    device)
                    volume_prediction_axi, _, header = _segment_vol(file_path, model2, "AXI", batch_size,
                                                                    cuda_available,
                                                                    device)

                # Fuse the two views: sum the outputs, then arg-max over axis 1.
                _, volume_prediction = torch.max(volume_prediction_axi + volume_prediction_cor, dim=1)
                volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
                volume_prediction = np.squeeze(volume_prediction)

                # Rebuild the affine from the header's srow_* rows.
                affine = np.array([
                    header['srow_x'],
                    header['srow_y'],
                    header['srow_z'],
                    [0, 0, 0, 1]
                ])
                nifti_img = nib.Nifti1Image(volume_prediction, affine, header=header)

                log.info("Processed: " + volumes_to_use[vol_idx] + " " + str(vol_idx + 1) + " out of " + str(
                    len(file_paths)))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_to_use[vol_idx] + str('.nii.gz')))

                per_volume_dict = compute_volume(volume_prediction, label_names, volumes_to_use[vol_idx])
                volume_dict_list.append(per_volume_dict)

            except FileNotFoundError as exp:
                log.error("Error in reading the file ...")
                log.exception(exp)
                if exit_on_error:
                    raise
            except Exception as exp:
                # Best effort unless exit_on_error: log and go on with the
                # next volume.
                log.exception(exp)
                if exit_on_error:
                    raise

        _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)

        if estimate_uncertainty:
            _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
            _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    log.info("DONE")
コード例 #10
0
def evaluate(coronal_model_path, volumes_txt_file, data_dir, device, prediction_path, batch_size, orientation,
             label_names, dir_struct, need_unc=False, mc_samples=0, exit_on_error=False):
    """Segment every listed volume with one model and write predictions and CSV tables.

    Predicted labels are mapped back with
    ``preprocessor.remap_labels_back(..., remap_config='SLANT')`` and saved
    as NIfTI in ``prediction_path`` with the affine built from the header's
    ``srow_x/y/z`` rows. ``volume_estimates.csv`` and, if requested, the two
    uncertainty tables are written at the end.

    Args:
        coronal_model_path: path of the serialized model passed to ``torch.load``.
        volumes_txt_file: text file with one volume id per line.
        data_dir: data directory forwarded to ``du.load_file_paths_eval``.
        device: an int selects that CUDA device (falling back to the CPU with
            a warning when CUDA is unavailable); anything else is passed to
            ``torch.device`` (e.g. ``'cpu'``).
        prediction_path: output directory (created if missing).
        batch_size: forwarded to the segmentation helper.
        orientation: forwarded to the segmentation helper.
        label_names: label names used for the volume/uncertainty tables.
        dir_struct: forwarded to ``du.load_file_paths_eval``.
        need_unc: enables uncertainty estimation; both the boolean ``True``
            and the string ``"True"`` are accepted.
        mc_samples: forwarded to ``_segment_vol_unc`` (presumably the number
            of Monte-Carlo samples -- confirm against that helper).
        exit_on_error: when true, the first per-volume exception is
            re-raised after logging instead of skipping the volume.

    Raises:
        Exception: any per-volume error, only when ``exit_on_error`` is true.
    """
    log.info("**Starting evaluation**")
    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    # The default is the bool False while configs pass the string "True";
    # accept both so that need_unc=True is not silently ignored.
    estimate_uncertainty = need_unc is True or need_unc == "True"

    cuda_available = torch.cuda.is_available()
    # An int device means "run on that GPU"; fall back to the CPU if there is none.
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU. ' + \
            'This can take much longer (> 1 hour). Cancel and ' + \
            'investigate if this behavior is not desired.'
        )
        device = 'cpu'

    if isinstance(device, int):
        model = torch.load(coronal_model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        # Covers 'cpu', other device strings and torch.device objects.
        model = torch.load(coronal_model_path, map_location=torch.device(device))

    model.eval()

    common_utils.create_if_not(prediction_path)
    log.info("Evaluating now...")
    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    with torch.no_grad():
        volume_dict_list = []
        cvs_dict_list = []
        iou_dict_list = []
        for vol_idx, file_path in enumerate(file_paths):
            try:
                if estimate_uncertainty:
                    _, volume_prediction, mc_pred_list, header = _segment_vol_unc(file_path, model, orientation,
                                                                                  batch_size, mc_samples,
                                                                                  cuda_available, device)
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names,
                                                                       volumes_to_use[vol_idx])
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    _, volume_prediction, header = _segment_vol(file_path, model, orientation, batch_size,
                                                                cuda_available,
                                                                device)

                # NOTE(review): the remap configuration is hard-coded to 'SLANT'.
                volume_prediction = preprocessor.remap_labels_back(volume_prediction, remap_config='SLANT')

                # Rebuild the affine from the header's srow_* rows.
                affine = np.array([
                    header['srow_x'],
                    header['srow_y'],
                    header['srow_z'],
                    [0, 0, 0, 1]
                ])
                nifti_img = nib.Nifti1Image(volume_prediction, affine, header=header)
                log.info("Processed: " + volumes_to_use[vol_idx] + " " + str(vol_idx + 1) + " out of " + str(
                    len(file_paths)))
                save_file = os.path.join(prediction_path, volumes_to_use[vol_idx])
                # Look at the file name only, so a '.nii' somewhere in the
                # directory part cannot suppress the extension.
                if '.nii' not in os.path.basename(save_file):
                    save_file += '.nii.gz'
                nib.save(nifti_img, save_file)
                per_volume_dict = compute_volume(volume_prediction, label_names, volumes_to_use[vol_idx])
                volume_dict_list.append(per_volume_dict)
            except FileNotFoundError as exp:
                log.error("Error in reading the file ...")
                log.exception(exp)
                if exit_on_error:
                    raise
            except Exception as exp:
                # Best effort unless exit_on_error: log and go on with the
                # next volume.
                log.exception(exp)
                if exit_on_error:
                    raise

        _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)

        if estimate_uncertainty:
            _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
            _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    log.info("DONE")
コード例 #11
0
def evaluate_dice_score(model_path,
                        num_classes,
                        query_labels,
                        data_dir,
                        query_txt_file,
                        support_txt_file,
                        remap_config,
                        orientation,
                        prediction_path, device=0, logWriter=None, mode='eval', fold=None):
    """Evaluate a conditioner/segmentor model and return the mean per-label Dice.

    For every label in ``query_labels`` the first support volume returned by
    ``du.load_file_paths`` for ``support_txt_file`` is loaded and passed
    through ``binarize_label``; ten slices are then picked from it at roughly
    even spacing. Each query volume is cropped to the inclusive slice range
    returned by ``get_range`` for that label, split into ten consecutive
    chunks, and every chunk is segmented with the weights that
    ``model.conditioner`` produces from the matching support slice.

    The prediction, the cropped query input, the support input, the support
    label map and the query ground truth are written to ``prediction_path``
    as ``.mgz`` files.

    Args:
        model_path: Path handed to ``torch.load``; the loaded object must
            expose ``conditioner`` and ``segmentor``.
        num_classes: Unused; kept for interface compatibility.
        query_labels: Iterable of label values to evaluate.
        data_dir: Passed as both directory arguments of
            ``du.load_file_paths``.
        query_txt_file: Text file with one query volume name per line; the
            names are used for the output file names.
        support_txt_file: Text file listing support volumes; only the first
            resulting path is used.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Output directory (created if missing).
        device: CUDA device used when CUDA is available; ignored otherwise.
        logWriter: Unused; kept for interface compatibility.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Tag used in the progress message and in output file names.
            ``None`` is treated as an empty tag.

    Returns:
        The mean over ``query_labels`` of the median per-volume Dice score.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")
    print("Loading model => " + model_path)
    num_support = 10  # number of support slices, and of query chunks per volume
    chunk_size = 10   # slices pushed through the network at once
    # The default fold=None used to crash on string concatenation.
    fold_tag = '' if fold is None else str(fold)

    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        model = torch.load(model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        # Remap CUDA-saved storages so the checkpoint can be opened on CPU.
        model = torch.load(model_path, map_location='cpu')

    model.eval()

    common_utils.create_if_not(prediction_path)

    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Load the (single) support volume used for this label.
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0],
                                                                            orientation=orientation,
                                                                            remap_config=remap_config)
            if len(support_volume.shape) != 4:
                # Insert a channel axis: (slices, H, W) -> (slices, 1, H, W).
                support_volume = support_volume[:, np.newaxis, :, :]
            support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
            support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)

            support_volume, _ = binarize_label(support_volume, support_labelmap, query_label)

            # num_support + 1 evenly spaced positions; dropping the last one and
            # shifting by half a gap yields exactly num_support slice indexes.
            half_gap = (len(support_volume) // num_support) // 2
            support_slice_indexes = np.round(
                np.linspace(0, len(support_volume) - 1, num_support + 1)).astype(int)
            support_slice_indexes = support_slice_indexes[:-1] + half_gap

            for vol_idx, file_path in enumerate(query_file_paths):
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path,
                                                                            orientation=orientation,
                                                                            remap_config=remap_config)
                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(torch.LongTensor)

                # Binary ground truth for this label, cropped to the inclusive
                # slice range reported by get_range.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]: range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1]

                # Chunk boundaries; linspace always returns num_support entries.
                query_slice_indexes = np.round(
                    np.linspace(0, len(query_volume) - 1, num_support)).astype(int)

                volume_prediction = []
                for i, query_start_slice in enumerate(query_slice_indexes):
                    if query_start_slice == query_slice_indexes[-1]:
                        # Last boundary: take everything up to the end.
                        query_batch_x = query_volume[query_start_slice:]
                    else:
                        query_batch_x = query_volume[query_start_slice:query_slice_indexes[i + 1]]

                    support_batch_x = support_volume[support_slice_indexes[i]]

                    # Run larger chunks through the network in smaller batches.
                    for b in range(0, len(query_batch_x), chunk_size):
                        query_chunk = query_batch_x[b:b + chunk_size]
                        # One copy of the support slice per query slice.
                        support_chunk = support_batch_x.repeat(len(query_chunk), 1, 1, 1)
                        if cuda_available:
                            query_chunk = query_chunk.cuda(device)
                            support_chunk = support_chunk.cuda(device)

                        weights = model.conditioner(support_chunk)
                        out = model.segmentor(query_chunk, weights)

                        _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
                        volume_prediction.append(batch_output)

                volume_prediction = torch.cat(volume_prediction)

                # Keep the target on the same device as the prediction.
                target = query_labelmap.cuda(device) if cuda_available else query_labelmap
                # Repeated chunk boundaries can yield surplus trailing slices,
                # hence the truncation to the ground-truth length.
                volume_dice_score = dice_score_binary(volume_prediction[:len(query_labelmap)],
                                                      target, phase=mode)

                volume_name = volumes_query[vol_idx]
                exports = (
                    # prediction
                    ('_' + fold_tag + '.mgz', volume_prediction.cpu().numpy().astype('float32')),
                    # query input
                    ('_Input_.mgz', query_volume.cpu().numpy()),
                    # conditioning (support) input
                    ('_CondInput_.mgz', support_volume.cpu().numpy()),
                    # conditioning (support) label map
                    ('_CondInputGT_.mgz', support_labelmap.cpu().numpy().astype('float32')),
                    # query ground truth
                    ('_GT_' + fold_tag + '.mgz', query_labelmap.cpu().numpy()),
                )
                for suffix, array in exports:
                    nifti_img = nib.MGHImage(np.squeeze(array), np.eye(4))
                    nib.save(nifti_img, os.path.join(prediction_path, volume_name + suffix))

                volume_dice_score = volume_dice_score.item()
                volume_dice_score_list.append(volume_dice_score)

                print(volume_dice_score)

            print(volume_dice_score_list)
            dice_score_arr = np.asarray(volume_dice_score_list)
            # Despite the name, the per-label summary is the median.
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)
    print("DONE")

    return np.mean(all_query_dice_score_list)
コード例 #12
0
def _load_three_views(file_path, orientation, remap_config):
    """Load one volume and return ``[(volume, labelmap), ...]`` for three axis orders.

    The first entry is the array as returned by ``du.load_and_preprocess``;
    the second and third are its ``(1, 2, 0)`` and ``(2, 0, 1)`` transposes.
    Volumes get a channel axis inserted when they are not already 4-D and are
    converted to float tensors; label maps become long tensors.
    """
    volume, labelmap, _, _ = du.load_and_preprocess(file_path,
                                                    orientation=orientation,
                                                    remap_config=remap_config)
    views = []
    for axes in (None, (1, 2, 0), (2, 0, 1)):
        view_volume = volume if axes is None else volume.transpose(axes)
        view_labelmap = labelmap if axes is None else labelmap.transpose(axes)
        if len(view_volume.shape) != 4:
            view_volume = view_volume[:, np.newaxis, :, :]
        views.append((torch.tensor(view_volume).type(torch.FloatTensor),
                      torch.tensor(view_labelmap).type(torch.LongTensor)))
    return views


def _predict_single_view(model, query_volume, support_volume, batch_size, cuda_available, device):
    """Segment ``query_volume`` batch by batch and return the concatenated raw outputs.

    The support window is refreshed on every second batch; in between, the
    previously repeated support tensor is reused. The last slice of the
    current support tensor conditions the whole query batch.
    """
    outputs = []
    support_batch_x = []
    for batch_no, start in enumerate(range(0, len(query_volume), batch_size)):
        query_batch_x = query_volume[start: start + batch_size]
        if batch_no % 2 == 0:
            support_batch_x = support_volume[start: start + batch_size]
        # [-1] equals [batch_size - 1] for full batches and, unlike the fixed
        # index, does not raise IndexError on a shorter final batch.
        support_batch_x = support_batch_x[-1].repeat(query_batch_x.size()[0], 1, 1, 1)
        if cuda_available:
            query_batch_x = query_batch_x.cuda(device)
            support_batch_x = support_batch_x.cuda(device)

        weights = model.conditioner(support_batch_x)
        outputs.append(model.segmentor(query_batch_x, weights))
    return torch.cat(outputs)


def evaluate_dice_score_3view(model1_path,
                              model2_path,
                              model3_path,
                              num_classes,
                              query_labels,
                              data_dir,
                              query_txt_file,
                              support_txt_file,
                              remap_config,
                              orientation1,
                              prediction_path, device=0, logWriter=None, mode='eval', fold=None):
    """Evaluate a three-view ensemble and return the mean per-label Dice.

    Each model handles one axis order of the same volume (the loaded array
    and its ``(1, 2, 0)`` / ``(2, 0, 1)`` transposes). The softmax outputs of
    the three models are permuted back to the first view's layout, combined
    with weights of 0.33 each, and the arg-max is scored against the binary
    query label map with ``dice_score_binary``. The ensemble prediction of
    every query volume is saved to ``prediction_path`` as an ``.mgz`` file.

    Args:
        model1_path, model2_path, model3_path: Paths handed to
            ``torch.load``; each loaded object must expose ``conditioner``
            and ``segmentor``.
        num_classes: Unused; kept for interface compatibility.
        query_labels: Iterable of label values to evaluate.
        data_dir: Passed as both directory arguments of
            ``du.load_file_paths``.
        query_txt_file: Text file with one query volume name per line; the
            names are used for the output file names.
        support_txt_file: Text file listing support volumes.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation1: Orientation forwarded to ``du.load_and_preprocess``.
        prediction_path: Output directory (created if missing).
        device: CUDA device used when CUDA is available; ignored otherwise.
        logWriter: Unused; kept for interface compatibility.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Tag used in the progress message and in output file names.
            ``None`` is treated as an empty tag.

    Returns:
        The mean over ``query_labels`` of the median per-volume Dice score.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")
    print("Loading model => " + model1_path + " and " + model2_path + " and " + model3_path)
    batch_size = 10
    # The default fold=None used to crash on string concatenation.
    fold_tag = '' if fold is None else str(fold)

    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    # NOTE: torch.load unpickles the files; only load checkpoints from trusted sources.
    cuda_available = torch.cuda.is_available()
    model_paths = (model1_path, model2_path, model3_path)
    if cuda_available:
        models = [torch.load(path) for path in model_paths]
        torch.cuda.empty_cache()
        for model in models:
            model.cuda(device)
    else:
        # Remap CUDA-saved storages so the checkpoints can be opened on CPU.
        models = [torch.load(path, map_location='cpu') for path in model_paths]

    for model in models:
        model.eval()

    common_utils.create_if_not(prediction_path)

    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # NOTE(review): every support file is loaded but each iteration
            # overwrites the previous result, so only the last file is used.
            for file_path in support_file_paths:
                support_views = _load_three_views(file_path, orientation1, remap_config)
                support_volumes = [binarize_label(volume, labelmap, query_label)
                                   for volume, labelmap in support_views]

            for vol_idx, file_path in enumerate(query_file_paths):
                query_views = _load_three_views(file_path, orientation1, remap_config)
                # Only the first view's label map is needed as ground truth.
                query_labelmap1 = query_views[0][1] == query_label

                volume_prediction1, volume_prediction2, volume_prediction3 = [
                    _predict_single_view(model, query_volume, support_volume,
                                         batch_size, cuda_available, device)
                    for model, (query_volume, _), support_volume
                    in zip(models, query_views, support_volumes)
                ]

                # The permutes undo the (1, 2, 0) and (2, 0, 1) transposes so
                # all three probability maps share the first view's layout.
                volume_prediction = 0.33 * F.softmax(volume_prediction1, dim=1) + 0.33 * F.softmax(
                    volume_prediction2.permute(3, 1, 0, 2), dim=1) + 0.33 * F.softmax(
                    volume_prediction3.permute(2, 1, 3, 0), dim=1)
                _, batch_output = torch.max(volume_prediction, dim=1)

                # Keep the target on the same device as the prediction.
                target = query_labelmap1.cuda(device) if cuda_available else query_labelmap1
                volume_dice_score = dice_score_binary(batch_output, target, phase=mode)

                batch_output = (batch_output.cpu().numpy()).astype('float32')
                nifti_img = nib.MGHImage(np.squeeze(batch_output), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold_tag + '.mgz'))

                volume_dice_score = volume_dice_score.cpu().numpy()
                volume_dice_score_list.append(volume_dice_score)

                print(volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            # Despite the name, the per-label summary is the median.
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)
    print("DONE")

    return np.mean(all_query_dice_score_list)
コード例 #13
0
def evaluate_dice_score(model_path,
                        num_classes,
                        query_labels,
                        data_dir,
                        query_txt_file,
                        support_txt_file,
                        remap_config,
                        orientation,
                        prediction_path,
                        device=0,
                        logWriter=None,
                        mode='eval',
                        fold=None):
    """Evaluate a conditioner/segmentor model and return the mean per-label Dice.

    For every label in ``query_labels`` each support volume is loaded and
    passed through ``binarize_label``. Each query volume is cropped to the
    inclusive slice range returned by ``get_range`` for that label and split
    into at most 15 consecutive chunks; chunk ``k`` is segmented with the
    weights that ``model.conditioner`` produces from slice ``k`` of the last
    loaded support volume.

    Args:
        model_path: Path handed to ``torch.load``; the loaded object must
            expose ``conditioner`` and ``segmentor``.
        num_classes: Unused; kept for interface compatibility.
        query_labels: Iterable of label values to evaluate.
        data_dir: Passed as both directory arguments of
            ``du.load_file_paths``.
        query_txt_file: Text file listing the query volumes.
        support_txt_file: Text file listing the support volumes.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Directory created if missing (nothing is written to
            it by this function).
        device: CUDA device used when CUDA is available; ignored otherwise.
        logWriter: Unused; kept for interface compatibility.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Tag used in the progress message. ``None`` is treated as an
            empty tag.

    Returns:
        The mean over ``query_labels`` of the median per-volume Dice score.
    """
    print(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    print("Loading model => " + model_path)
    num_support = 15
    # The default fold=None used to crash on string concatenation.
    fold_tag = '' if fold is None else str(fold)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        model = torch.load(model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        # Remap CUDA-saved storages so the checkpoint can be opened on CPU.
        model = torch.load(model_path, map_location='cpu')

    model.eval()
    common_utils.create_if_not(prediction_path)

    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir,
                                            support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []

        for query_label in query_labels:
            volume_dice_score_list = []

            # NOTE(review): support_slices collects evenly spaced slices from
            # every support volume but is never read below; inference indexes
            # support_volume (the last one loaded) directly. This looks
            # unintended - confirm before switching, since support_slices can
            # hold fewer entries than there are query chunks.
            support_slices = []

            for file_path in support_file_paths:
                # Loading support
                support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                    file_path,
                    orientation=orientation,
                    remap_config=remap_config)

                if len(support_volume.shape) != 4:
                    # Insert a channel axis: (slices, H, W) -> (slices, 1, H, W).
                    support_volume = support_volume[:, np.newaxis, :, :]
                support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
                support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)

                support_volume, _ = binarize_label(
                    support_volume, support_labelmap, query_label)

                slice_gap_support = int(
                    np.ceil(len(support_volume) / num_support))
                support_slice_indexes = list(
                    range(0, len(support_volume), slice_gap_support))
                if len(support_slice_indexes) < num_support:
                    support_slice_indexes.append(len(support_volume) - 1)

                support_slices.extend(
                    support_volume[idx] for idx in support_slice_indexes)

            for file_path in query_file_paths:
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path,
                    orientation=orientation,
                    remap_config=remap_config)

                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(torch.LongTensor)

                # Binary ground truth for this label, cropped to the inclusive
                # slice range reported by get_range.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]:range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]:range_query[1] + 1]

                slice_gap_query = int(np.ceil(len(query_volume) / num_support))
                batch_output_arr = []
                for support_slice_idx, start in enumerate(
                        range(0, len(query_volume), slice_gap_query)):
                    query_batch_x = query_volume[start:start + slice_gap_query]
                    # One copy of the support slice per query slice.
                    support_batch_x = support_volume[support_slice_idx].repeat(
                        query_batch_x.size()[0], 1, 1, 1)

                    if cuda_available:
                        query_batch_x = query_batch_x.cuda(device)
                        support_batch_x = support_batch_x.cuda(device)

                    # Inference must run whether or not CUDA is present; it was
                    # previously nested under the CUDA branch, which left the
                    # output list empty on CPU.
                    weights = model.conditioner(support_batch_x)
                    out = model.segmentor(query_batch_x, weights)

                    _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
                    batch_output_arr.append(batch_output)

                volume_output = torch.cat(batch_output_arr)
                # Keep the target on the same device as the prediction.
                target = query_labelmap.cuda(device) if cuda_available else query_labelmap
                volume_dice_score = dice_score_binary(
                    volume_output, target, phase=mode)
                volume_dice_score_list.append(volume_dice_score.item())
                print(str(file_path), volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            # Despite the name, the per-label summary is the median.
            avg_dice_score = np.median(dice_score_arr)
            print(volume_dice_score_list)
            print('Query Label -> ' + str(query_label) + ' ' +
                  str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")

    return np.mean(all_query_dice_score_list)
コード例 #14
0
ファイル: evaluator.py プロジェクト: ai-med/AbdomenNet
def evaluate_dice_score(model_path,
                        num_classes,
                        data_dir,
                        label_dir,
                        volumes_txt_file,
                        orientation,
                        prediction_path,
                        device=0,
                        logWriter=None,
                        mode='eval',
                        multi_channel=False,
                        use_2channel=False,
                        thick_ch=False):
    """Segment every listed volume slice-wise and report per-class Dice scores.

    For each volume the image (and optional extra channels) is loaded with
    nibabel, rolled so that the slicing axis selected by ``orientation`` comes
    first, cropped with ``remove_black_3channels``, pushed through the model
    in batches of 15 slices and compared against the label map. The predicted
    label volume is written to
    ``<prediction_path>/<volume name>_new.nii.gz``.

    Args:
        model_path: Path of the serialized model passed to ``torch.load``.
        num_classes: Number of classes; Dice is computed for labels
            ``0 .. num_classes - 1``.
        data_dir: Directory forwarded to the ``du.load_file_paths*`` helper.
        label_dir: Label directory forwarded to the same helper.
        volumes_txt_file: Text file with one volume name per line.
        orientation: Key into ``to_axis_dict`` selecting the axis that is
            rolled to the front.
        prediction_path: Directory that receives the predicted volumes.
        device: ``int`` selects a CUDA device index (falls back to CPU with a
            warning when CUDA is unavailable); a ``str`` is handed to
            ``torch.device``.
        logWriter: Optional writer; when truthy, ``plot_dice_score`` is called
            per volume and ``plot_eval_box_plot`` once at the end.
        mode: Forwarded to ``dice_score_perclass``.
        multi_channel: Stack water, image and the volume at ``file_path[4]``
            as three input channels.
        use_2channel: Stack water and image as two input channels.
        thick_ch: When neither multi-channel flag is set, feed each slice
            together with its two neighbours on either side (5 channels).

    Returns:
        Tuple ``(avg_dice_score, class_dist)``: the mean Dice over all volumes
        and classes, and a list holding, for every class, the array of
        per-volume scores.
    """
    log.info(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    batch_size = 15

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()
    if multi_channel or use_2channel:
        file_paths = du.load_file_paths_3channel(data_dir, label_dir,
                                                 volumes_txt_file)
    else:
        file_paths = du.load_file_paths(data_dir, label_dir, volumes_txt_file)

    cuda_available = torch.cuda.is_available()
    # First, are we attempting to run on a GPU?
    if isinstance(device, int):
        # if CUDA available, follow through, else warn and fallback to CPU
        if cuda_available:
            model = torch.load(model_path)
            torch.cuda.empty_cache()
            model.cuda(device)
        else:
            log.warning(
                'CUDA is not available, trying with CPU.' + \
                'This can take much longer (> 1 hour). Cancel and ' + \
                'investigate if this behavior is not desired.'
            )
            # switch device to 'cpu'
            device = 'cpu'
    # If device is a string (e.g. 'cpu') or CUDA is not available
    if isinstance(device, str) or not cuda_available:
        model = torch.load(model_path, map_location=torch.device(device))

    # Single device object used for every tensor that meets the model, so the
    # CPU fallback and string devices follow the same code path as the GPU.
    run_device = (torch.device('cuda', device)
                  if isinstance(device, int) else torch.device(device))

    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    log.info("Evaluating now...")

    slice_axis = to_axis_dict[orientation]
    # The thick-slice path is only reached when no multi-channel flag is set.
    use_thick = thick_ch and not (multi_channel or use_2channel)

    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            img, label = nb.load(file_path[0]), nb.load(file_path[1])
            header, affine = img.header, img.affine
            volume = np.rollaxis(img.get_fdata(), slice_axis, 0)
            labelmap = np.rollaxis(label.get_fdata(), slice_axis, 0)
            # Full-size canvas; the cropped prediction is pasted back into it.
            template = np.zeros_like(labelmap)

            # NOTE(review): remove_black_3channels presumably crops empty
            # slices and returns the kept slice range as S, E -- confirm.
            if multi_channel:
                water = np.rollaxis(
                    nb.load(file_path[2]).get_fdata(), slice_axis, 0)
                # fat (file_path[3]) is intentionally not used here.
                inv = np.rollaxis(
                    nb.load(file_path[4]).get_fdata(), slice_axis, 0)
                volume, _, water, labelmap, inv, S, E = remove_black_3channels(
                    volume, None, water, labelmap, inv, return_indices=True)
                volume = np.array([
                    np.stack([w, v, ij], axis=0)
                    for w, v, ij in zip(water, volume, inv)
                ])
            elif use_2channel:
                water = np.rollaxis(
                    nb.load(file_path[2]).get_fdata(), slice_axis, 0)
                volume, _, water, labelmap, _, S, E = remove_black_3channels(
                    volume, None, water, labelmap, None, return_indices=True)
                print(volume.shape, water.shape, labelmap.shape)
                volume = np.array([
                    np.stack([w, v], axis=0) for v, w in zip(volume, water)
                ])
            else:
                volume, _, _, labelmap, _, S, E = remove_black_3channels(
                    volume, None, None, labelmap, None, return_indices=True)
                print(volume.shape, labelmap.shape)

            # Ensure a channel axis: (slices, channels, H, W).
            volume = volume if len(
                volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume, labelmap = torch.tensor(volume).type(
                torch.FloatTensor), torch.tensor(labelmap).type(
                    torch.LongTensor)

            volume_prediction = []
            for i in range(0, len(volume), batch_size):
                if use_thick:
                    volume = np.squeeze(volume)
                    num_slices = volume.shape[0]
                    batch_x = []
                    # Clamp to the volume length: the final batch is usually
                    # shorter than batch_size and must not index past the end.
                    for index in range(i, min(i + batch_size, num_slices)):
                        # NOTE(review): the edge handling below repeats the
                        # centre slice slightly earlier than strictly needed
                        # (index 1 and index num_slices - 3); kept unchanged
                        # so inputs match what the model was trained on.
                        if index < 2:
                            n1, n2 = index, index
                        else:
                            n1, n2 = index - 1, index - 2

                        if index >= num_slices - 3:
                            p1, p2 = index, index
                        else:
                            p1, p2 = index + 1, index + 2

                        batch_x.append(
                            np.stack([
                                volume[n2], volume[n1], volume[index],
                                volume[p1], volume[p2]
                            ],
                                     axis=0))
                    batch_x = torch.tensor(np.array(batch_x)).type(
                        torch.FloatTensor)
                else:
                    batch_x = volume[i:i + batch_size]

                batch_x = batch_x.to(run_device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                volume_prediction.append(batch_output)

            volume_prediction = torch.cat(volume_prediction)
            # Put the labels wherever the prediction lives instead of forcing
            # CUDA, which failed on the CPU fallback.
            volume_dice_score = dice_score_perclass(
                volume_prediction,
                labelmap.to(volume_prediction.device),
                np.arange(0, num_classes),
                mode=mode)

            volume_prediction = (
                volume_prediction.cpu().numpy()).astype('int16')
            print("evaluator here")
            header.set_data_dtype('int16')
            volume_prediction = np.squeeze(volume_prediction)

            # Paste the cropped prediction back and restore the axis order.
            template[S:E] = volume_prediction
            volume_prediction = np.rollaxis(template, 0, slice_axis + 1)

            nifti_img = nb.Nifti1Image(volume_prediction,
                                       affine,
                                       header=header)
            nb.save(
                nifti_img,
                os.path.join(prediction_path,
                             volumes_to_use[vol_idx] + str('_new.nii.gz')))
            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score',
                                          volume_dice_score,
                                          volumes_to_use[vol_idx],
                                          np.arange(0,
                                                    num_classes), num_classes)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            # Logging takes a format string first; passing the array as the
            # message made the extra argument a formatting error.
            log.info("%s %s", volume_dice_score, np.mean(volume_dice_score))

        dice_score_arr = np.asarray(volume_dice_score_list)
        avg_dice_score = np.mean(dice_score_arr)
        avg_dice_score_wo_bg = np.mean(dice_score_arr[:, 1:])
        log.info("Mean of dice score : " + str(avg_dice_score))
        print('Mean dice score: ', avg_dice_score)
        print('Mean dice score without background: ', avg_dice_score_wo_bg)
        print('all dice scores: ', dice_score_arr)
        print('class wise mean dice scores: ', np.mean(dice_score_arr, axis=0))
        class_dist = [dice_score_arr[:, c] for c in range(num_classes)]

        if logWriter:
            logWriter.plot_eval_box_plot('eval_dice_score_box_plot',
                                         class_dist, 'Box plot Dice Score')
    log.info("DONE")

    return avg_dice_score, class_dist
コード例 #15
0
ファイル: evaluator.py プロジェクト: ai-med/AbdomenNet
def evaluate3view(coronal_model_path,
                  axial_model_path,
                  sagittal_model_path,
                  volumes_txt_file,
                  data_dir,
                  label_dir,
                  device,
                  prediction_path,
                  batch_size,
                  label_names,
                  label_list,
                  exit_on_error=False,
                  multi_channel=False,
                  use_2channel=False):
    """Fuse coronal, axial and sagittal model predictions and score them.

    Every volume is segmented once per view with ``_segment_vol``. The three
    outputs are soft-maxed over dim 1, summed, and the arg-max is taken as the
    final label map. Per-class Dice scores are computed with
    ``dice_score_perclass`` and averaged over all volumes; the averages are
    printed. Each fused prediction is saved to
    ``<prediction_path>/<axial>_<coronal>_<sagittal>/<volume name>.nii.gz``
    where the three name parts are the model file names up to the first dot.

    Args:
        coronal_model_path: Serialized model used for the "COR" view.
        axial_model_path: Serialized model used for the "AXI" view.
        sagittal_model_path: Serialized model used for the "SAG" view.
        volumes_txt_file: Text file with one volume name per line.
        data_dir: Directory forwarded to the ``du.load_file_paths*`` helper.
        label_dir: Label directory forwarded to the same helper.
        device: ``int`` selects a CUDA device index (falls back to CPU with a
            warning when CUDA is unavailable); a ``str`` is handed to
            ``torch.device``.
        prediction_path: Root directory for the saved predictions.
        batch_size: Forwarded to ``_segment_vol``.
        label_names: Not used by this function.
        label_list: Labels forwarded to ``dice_score_perclass``.
        exit_on_error: Not used by this function.
        multi_channel: Selects the 3-channel file list; forwarded to
            ``_segment_vol``.
        use_2channel: Selects the 3-channel file list; forwarded to
            ``_segment_vol``.

    Returns:
        None. Results are printed, logged and written to disk.
    """
    log.info("**Starting evaluation**")

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    if multi_channel or use_2channel:
        file_paths = du.load_file_paths_3channel(data_dir, label_dir,
                                                 volumes_txt_file)
    else:
        file_paths = du.load_file_paths(data_dir, label_dir, volumes_txt_file)

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int):
        # if CUDA available, follow through, else warn and fallback to CPU
        if cuda_available:
            model1 = torch.load(coronal_model_path)
            model2 = torch.load(axial_model_path)
            model3 = torch.load(sagittal_model_path)

            torch.cuda.empty_cache()
            model1.cuda(device)
            model2.cuda(device)
            model3.cuda(device)
        else:
            log.warning(
                'CUDA is not available, trying with CPU.' + \
                'This can take much longer (> 1 hour). Cancel and ' + \
                'investigate if this behavior is not desired.'
            )
            # An int passed to torch.device means a CUDA index, so the
            # fallback has to switch to the string 'cpu' explicitly.
            device = 'cpu'

    if isinstance(device, str) or not cuda_available:
        map_location = torch.device(device)
        model1 = torch.load(coronal_model_path, map_location=map_location)
        model2 = torch.load(axial_model_path, map_location=map_location)
        # The sagittal view must use the sagittal checkpoint (this branch
        # previously reloaded the axial one).
        model3 = torch.load(sagittal_model_path, map_location=map_location)

    # Device on which the Dice score is computed; equals the CUDA index on
    # GPU runs and the CPU on the fallback path.
    metric_device = (torch.device('cuda', device)
                     if isinstance(device, int) else torch.device(device))

    model1.eval()
    model2.eval()
    model3.eval()

    common_utils.create_if_not(prediction_path)
    log.info("Evaluating now...")

    print(file_paths)

    # Output sub-directory is named after the three checkpoints; it does not
    # change between volumes.
    ax = axial_model_path.split('/')[-1].split('.')[0]
    co = coronal_model_path.split('/')[-1].split('.')[0]
    sa = sagittal_model_path.split('/')[-1].split('.')[0]
    fused_dir = f'{prediction_path}/{ax}_{co}_{sa}'

    with torch.no_grad():
        # Sized from the first score vector rather than a hard-coded class
        # count, so any length of label_list works.
        all_dice_scores = None
        for vol_idx, file_path in enumerate(file_paths):

            volume_prediction_cor, (_, reference_label), _, header = \
                _segment_vol(file_path, model1, "COR", batch_size,
                             cuda_available, device, multi_channel,
                             use_2channel)
            print('segment cor')
            volume_prediction_axi, (_, reference_label), _, header = \
                _segment_vol(file_path, model2, "AXI", batch_size,
                             cuda_available, device, multi_channel,
                             use_2channel)
            print('segment axi')
            volume_prediction_sag, (_, reference_label), _, header = \
                _segment_vol(file_path, model3, "SAG", batch_size,
                             cuda_available, device, multi_channel,
                             use_2channel)
            print('segment sag')

            volume_prediction_axi = F.softmax(volume_prediction_axi, dim=1)
            volume_prediction_cor = F.softmax(volume_prediction_cor, dim=1)
            volume_prediction_sag = F.softmax(volume_prediction_sag, dim=1)

            # Sum of the per-view class probabilities, then arg-max.
            _, volume_prediction = torch.max(volume_prediction_axi +
                                             volume_prediction_sag +
                                             volume_prediction_cor,
                                             dim=1)

            volume_prediction = (
                volume_prediction.cpu().numpy()).astype('float32')

            reference_label = torch.from_numpy(reference_label).to(
                metric_device)
            volume_dice_score = dice_score_perclass(
                torch.from_numpy(volume_prediction).to(metric_device),
                reference_label,
                label_list,
                mode='eval')
            print(volume_dice_score)

            scores = volume_dice_score.cpu().numpy()
            if all_dice_scores is None:
                all_dice_scores = np.zeros(scores.shape, dtype=np.float64)
            all_dice_scores += scores

            volume_prediction = np.squeeze(volume_prediction).astype('int')

            Mat = header.get_best_affine()

            nifti_img = nb.MGHImage(volume_prediction, Mat, header=header)

            log.info("Processed: " + volumes_to_use[vol_idx] + " " +
                     str(vol_idx + 1) + " out of " + str(len(file_paths)))
            common_utils.create_if_not(fused_dir)
            nb.save(
                nifti_img,
                os.path.join(fused_dir,
                             volumes_to_use[vol_idx] + str('.nii.gz')))

            # Release the large per-volume tensors before the next volume.
            del volume_prediction, volume_prediction_axi, volume_dice_score, volume_prediction_cor, volume_prediction_sag

        if all_dice_scores is None:
            log.warning("No volumes were evaluated; skipping Dice summary.")
        else:
            all_dice_scores /= len(file_paths)
            print('avg dice scores: ', all_dice_scores)
            print('mean dice: ', np.mean(all_dice_scores))
            print('mean dice without background: ',
                  np.mean(all_dice_scores[1:]))

    log.info("DONE")
コード例 #16
0
def evaluate_dice_score(model_path,
                        num_classes,
                        query_labels,
                        data_dir,
                        query_txt_file,
                        support_txt_file,
                        remap_config,
                        orientation,
                        prediction_path,
                        device=0,
                        logWriter=None,
                        mode='eval',
                        fold=None):
    """Evaluate a conditioner/segmentor model on query volumes per label.

    For every label in ``query_labels`` the first support volume is loaded
    and passed through ``binarize_label``, and up to ``Num_support`` (10)
    evenly spaced support slices are selected. Each query volume is cropped
    to the slice range returned by ``get_range`` for that label and segmented
    once per selected support slice; the raw segmentor outputs are averaged
    over the support slices before the arg-max is taken. The binary Dice of
    each query volume is collected and the median per label is printed.

    Side effects: writes ``SupportInput_.mgz`` and ``SupportGT_.mgz`` (channels
    0 and 1 of the binarized support volume) into ``prediction_path``; the
    files are overwritten for every query label.

    Args:
        model_path: Path of the serialized model passed to ``torch.load``.
        num_classes: Not used by this function.
        query_labels: Iterable of label ids to evaluate.
        data_dir: Directory forwarded to ``du.load_file_paths`` (used for
            both the data and the label location).
        query_txt_file: Volume list for the query set.
        support_txt_file: Volume list for the support set; only its first
            entry is used.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Directory that receives the support images.
        device: CUDA device index used when CUDA is available.
        logWriter: Not used by this function.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Identifier appended to the "Evaluating now" message.

    Returns:
        Mean over all query labels of the per-label median Dice score.
    """
    print(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    print("Loading model => " + model_path)
    batch_size = 20
    Num_support = 10

    cuda_available = torch.cuda.is_available()
    # Without CUDA a GPU-saved checkpoint has to be remapped while loading.
    model = torch.load(model_path,
                       map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)

    model.eval()

    common_utils.create_if_not(prediction_path)

    # str() keeps the message working for the default fold=None.
    print("Evaluating now... " + str(fold))
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir,
                                            support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Loading support
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                support_file_paths[0],
                orientation=orientation,
                remap_config=remap_config)
            support_volume = support_volume if len(
                support_volume.shape) == 4 else support_volume[:, np.
                                                               newaxis, :, :]
            support_volume = torch.tensor(support_volume).type(
                torch.FloatTensor)
            support_labelmap = torch.tensor(support_labelmap).type(
                torch.LongTensor)

            # NOTE(review): binarize_label presumably returns the support
            # volume with the binary mask of query_label as channel 1 (channel
            # 1 is saved as "SupportGT" below) -- confirm.
            support_volume, _ = binarize_label(support_volume,
                                               support_labelmap, query_label)

            # Save the support input and its mask for inspection.
            nifti_img = nib.MGHImage(
                np.squeeze(support_volume[:, 0, :, :].cpu().numpy()),
                np.eye(4))
            nib.save(nifti_img,
                     os.path.join(prediction_path, 'SupportInput_.mgz'))

            nifti_img = nib.MGHImage(
                np.squeeze(support_volume[:, 1, :, :].cpu().numpy()),
                np.eye(4))
            nib.save(nifti_img,
                     os.path.join(prediction_path, 'SupportGT_.mgz'))

            print("Saved")

            # Evenly spaced support slices; add the last slice when the
            # stride yields fewer than Num_support entries.
            slice_gap_support = int(np.ceil(len(support_volume) / Num_support))
            support_slice_indexes = list(
                range(0, len(support_volume), slice_gap_support))
            if len(support_slice_indexes) < Num_support:
                support_slice_indexes.append(len(support_volume) - 1)

            for file_path in query_file_paths:

                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path,
                    orientation=orientation,
                    remap_config=remap_config)

                query_volume = query_volume if len(
                    query_volume.shape) == 4 else query_volume[:, np.
                                                               newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(
                    torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(
                    torch.LongTensor)

                # Binary mask of the current label, cropped together with the
                # volume to the range reported by get_range.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]:range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]:range_query[1] +
                                                1]

                vol_output = []
                for support_slice_idx in support_slice_indexes:
                    batch_output = []
                    for i in range(0, len(query_volume), batch_size):
                        query_batch_x = query_volume[i:i + batch_size]
                        # One copy of the support slice per query slice.
                        support_batch_x = support_volume[
                            support_slice_idx].repeat(query_batch_x.size()[0],
                                                      1, 1, 1)
                        if cuda_available:
                            query_batch_x = query_batch_x.cuda(device)
                            support_batch_x = support_batch_x.cuda(device)
                        weights = model.conditioner(support_batch_x)
                        out = model.segmentor(query_batch_x, weights)
                        batch_output.append(out)
                    vol_output.append(torch.cat(batch_output))

                # Average the raw outputs over all support slices, then pick
                # the most likely class per voxel.
                vol_output = torch.mean(torch.stack(vol_output), dim=0)
                _, vol_output = torch.max(F.softmax(vol_output, dim=1), dim=1)

                # Compare on the prediction's device so CPU-only runs work.
                volume_dice_score = dice_score_binary(
                    vol_output,
                    query_labelmap.to(vol_output.device),
                    phase=mode)
                # Store a plain float (dice_score_binary is assumed to return
                # a scalar tensor); np.asarray cannot convert CUDA tensors.
                volume_dice_score_list.append(volume_dice_score.item())

                print(volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' +
                  str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")

    return np.mean(all_query_dice_score_list)