コード例 #1
0
def compute_metrics_on_directories(root_dir, gt_name, pred_name, save_name):
    """
    Compute metrics for every patient directory of ``root_dir`` and save them
    to a timestamped csv file.

    Parameters
    ----------

    root_dir: string
        Directory containing one sub-directory per patient. Plain files in
        ``root_dir`` are ignored.

    gt_name: string
        File name of the ground truth segmentation map inside each patient
        directory.

    pred_name: string
        File name of the predicted segmentation map inside each patient
        directory.

    save_name: string
        Prefix of the output csv; the file is written to
        ``<save_name>_<YYYYmmdd_HHMMSS>.csv``.
    """
    patient_path_list = sorted(
        os.path.join(root_dir, entry) for entry in os.listdir(root_dir))
    # Only sub-directories are patients; stray files would make load_nrrd fail.
    patient_path_list = [p for p in patient_path_list if os.path.isdir(p)]

    rows = []
    for p_path in patient_path_list:
        p_name = os.path.basename(p_path)
        print(p_name)
        gt, gt_image = load_nrrd(os.path.join(p_path, gt_name))
        pred, _ = load_nrrd(os.path.join(p_path, pred_name))
        # Voxel spacing of the ground-truth image is forwarded to metrics().
        zooms = gt_image.GetSpacing()
        # Prepend the patient name so each row lines up with HEADER; list()
        # also accepts a tuple/array returned by metrics().
        rows.append([p_name] + list(metrics(gt, pred, zooms)))

    df = pd.DataFrame(rows, columns=HEADER)
    print(df.describe(include=[np.number]))
    df.to_csv(save_name + "_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")),
              index=False)
コード例 #2
0
def compute_metrics_on_files(path_gt, path_pred):
    """
    Compute, print and return the metrics for two files.

    Parameters
    ----------

    path_gt: string
        Path of the ground truth image.

    path_pred: string
        Path of the predicted image.

    Returns
    -------
    The value returned by ``metrics`` for this pair. It is also printed,
    preceded by the ground-truth voxel spacing.
    """
    gt, img = load_nrrd(path_gt)
    pred, _ = load_nrrd(path_pred)
    spacing = img.GetSpacing()
    res = metrics(gt, pred, spacing)
    print('spacing', spacing)
    print(res)
    return res
コード例 #3
0
    def __getitem__(self, index):
        """Return one randomly sampled training slice (or slice stack) of a patient.

        The volume of patient ``index`` is re-oriented according to
        ``self.orientation`` (0 axial, 1 coronal, 2 sagittal) and a slice index
        is drawn at random. Optional gamma correction, CLAHE and MIP are then
        applied, followed by ``self.pair_transform`` and per-channel
        zero-mean / unit-std normalisation.

        Returns
        -------
        dict
            ``'input'`` (float tensor) and ``'target'`` (long tensor). When
            ``self.extra_label`` is set, it also contains ``'p_name'``,
            ``'post_ablation'`` (0/1) and ``'has_object'`` (0/1).
        """
        # Re-seed so DataLoader workers do not sample the same augmentation
        # parameters; read the clock once so second/microsecond are consistent.
        now = datetime.datetime.now()
        np.random.seed(now.second + now.microsecond)

        if self.preload_data:
            image = np.copy(self.raw_images[index])
            target = np.copy(self.raw_labels[index])
        else:
            # NOTE(review): preloaded images may be loaded with an explicit
            # dtype; confirm this lazy path yields the same dtype.
            patient_dir = self.patient_path_list[index]
            image, _ = load_nrrd(os.path.join(patient_dir, 'lgemri.nrrd'))
            target, _ = load_nrrd(os.path.join(patient_dir, 'laendo.nrrd'))

        target = np.uint8(target)
        target[target >= 1] = 1  # binary mask
        has_object = 1 if np.sum(target) > 0 else 0

        # Move the slicing axis to the front (axial volumes are used as loaded).
        if self.orientation == 1:  # coronal
            image = np.transpose(image, (1, 0, 2))
            target = np.transpose(target, (1, 0, 2))
        elif self.orientation == 2:  # sagittal
            image = np.transpose(image, (2, 0, 1))
            target = np.transpose(target, (2, 0, 1))

        half_length = self.sequence_length // 2
        n_slices = image.shape[0]
        if self.orientation == 0:
            # Axial: pick a slice uniformly at random.
            slice_id = np.random.randint(0, n_slices)
            if self.if_subsequent:
                assert self.sequence_length >= 3
                # Keep the whole neighbourhood inside the volume.
                slice_id = np.random.randint(half_length,
                                             n_slices - half_length)
        else:
            # Coronal/sagittal: with probability 0.5 draw a slice that contains
            # the object, otherwise a background slice.
            p_indexes, n_indexes = self.get_p_n_group(target)
            use_positive = np.random.rand() >= 0.5
            # Fall back to the non-empty group instead of failing in
            # np.random.randint(0, 0) when one group has no slices.
            if len(p_indexes) == 0:
                use_positive = False
            elif len(n_indexes) == 0:
                use_positive = True
            group = p_indexes if use_positive else n_indexes
            slice_id = group[np.random.randint(0, len(group))]
            if self.if_subsequent:
                # The balanced draw ignores the volume border; clamp so the
                # neighbourhood start:end below stays inside the volume.
                slice_id = int(np.clip(slice_id, half_length,
                                       n_slices - half_length - 1))

        if self.if_subsequent:
            start = slice_id - half_length
            end = slice_id + half_length + 1
            image = image[start:end, :, :]
        else:
            image = image[[slice_id], :, :]  # keep a channel axis: c*h*w
        target = target[slice_id, :, :]

        original_data = image.copy()
        # np.float was an alias of the builtin float (float64) and was removed
        # in NumPy 1.24.
        clahe_output = np.zeros(image.shape, dtype=np.float64)
        for i in range(image.shape[0]):
            if self.gamma_correction:
                image[i] = automatic_gamma_correction(image[i])
            if self.if_clahe:
                clahe_output[i] = equalize_adapthist(image[i])
        if self.if_clahe:
            image = clahe_output

        if self.if_mip:
            assert image.shape[0] == 1
            # Combine a block of about 1/4 of the slices into a projection
            # (expected to produce a 3-channel image).
            # NOTE(review): original_data is the already-sliced single-slice
            # stack, not the full volume; confirm merge_mip_stack expects that.
            image = merge_mip_stack(original_data,
                                    slice_id,
                                    image[0],
                                    portion=4)

        new_input, new_target = self.pair_transform(image, target,
                                                    self.input_h, self.input_w)
        # Per-channel standardisation; the epsilon guards constant channels.
        new_input = new_input * 1.0
        new_input -= np.mean(new_input, axis=(1, 2), keepdims=True)
        new_input /= np.std(new_input, axis=(1, 2), keepdims=True) + 1e-11

        input_tensor = torch.from_numpy(new_input).float()
        target_tensor = torch.from_numpy(new_target).long()
        if not self.extra_label:
            return {'input': input_tensor, 'target': target_tensor}

        patient_name = os.path.basename(self.patient_path_list[index])
        matches = self.df.loc[self.df['Patient Code'] == patient_name,
                              'post_ablation?']
        # np.bool was an alias of the builtin bool and was removed in NumPy 1.24.
        class_label = 1 if bool(matches.values[0]) else 0
        return {
            'input': input_tensor,
            'target': target_tensor,
            'p_name': patient_name,
            'post_ablation': class_label,
            'has_object': has_object
        }
コード例 #4
0
    def __init__(self,
                 root_dir,
                 split,
                 extra_label=False,
                 if_subsequent=False,
                 sequence_length=1,
                 if_mip=False,
                 extra_label_csv_path='',
                 augmentation=False,
                 if_clahe=False,
                 if_gamma_correction=False,
                 preload_data=False,
                 input_h=224,
                 input_w=224,
                 orientation=0):
        """Index the patient folders of ``root_dir/split`` and store the options.

        Parameters
        ----------
        root_dir, split: string
            Patients are read from ``join(root_dir, split)``; every entry there
            is treated as one patient directory.
        extra_label: bool
            If True, load ``extra_label_csv_path`` into ``self.df``.
        if_subsequent: bool
            Use the former and next slices together with the sampled slice.
        sequence_length: int
            Number of slices in the stack when ``if_subsequent`` is set.
        if_mip: bool
            Enable the MIP input.
        extra_label_csv_path: string
            Csv file read with ``pd.read_csv`` when ``extra_label`` is True.
        augmentation: bool
            Stored as ``self.augmentation``.
        if_clahe, if_gamma_correction: bool
            Enable CLAHE / automatic gamma correction.
        preload_data: bool
            Load every ``lgemri.nrrd`` / ``laendo.nrrd`` pair into memory now.
        input_h, input_w: int
            Output height/width, stored for later use.
        orientation: int
            0: axial slices, 1: coronal slices, 2: sagittal slices.

        Raises
        ------
        FileNotFoundError
            If ``extra_label`` is set and ``extra_label_csv_path`` does not exist.
        """
        super(AtriaDataset, self).__init__()
        dataset_dir = join(root_dir, split)
        self.patient_list = os.listdir(dataset_dir)
        self.patient_path_list = sorted(
            [os.path.join(dataset_dir, pid) for pid in self.patient_list])
        self.data_size = len(self.patient_path_list)
        self.if_clahe = if_clahe
        # report the number of images in the dataset
        print('Number of {0} images: {1} nrrds'.format(split, self.data_size))

        # data augmentation / output options
        self.augmentation = augmentation
        self.input_h = input_h
        self.input_w = input_w
        self.split = split
        self.gamma_correction = if_gamma_correction
        self.extra_label = extra_label

        # optionally load all volumes into RAM up front
        self.preload_data = preload_data
        if self.preload_data:
            print('Preloading the {0} dataset ...'.format(split))
            self.raw_images = [
                load_nrrd(os.path.join(ii, 'lgemri.nrrd'),
                          dtype=sitk.sitkUInt8)[0]
                for ii in self.patient_path_list
            ]
            self.raw_labels = [
                load_nrrd(os.path.join(ii, 'laendo.nrrd'))[0]
                for ii in self.patient_path_list
            ]
            print('Loading is done\n')

        # csv providing the per-patient post-ablation class label
        if self.extra_label:
            # An explicit check (unlike assert) still runs under `python -O`.
            if not os.path.exists(extra_label_csv_path):
                raise FileNotFoundError(
                    'extra label csv not found: {}'.format(extra_label_csv_path))
            print('loading csv as extra label')
            self.df = pd.read_csv(extra_label_csv_path, header=0)
        self.if_mip = if_mip
        if self.if_mip:
            print('mip enabled')
        if self.gamma_correction:
            print('automatic gamma correction augmented')
        self.if_subsequent = if_subsequent  # use former and next slices
        # 0: axial slices, 1: coronal slices, 2: sagittal slices
        self.orientation = orientation
        self.sequence_length = sequence_length
コード例 #5
0
if __name__ == '__main__':
    # Sample script: run-length encode every patient's predicted mask, one
    # entry per slice, and collect the encodings into a DataFrame.
    mask_format_name = "predict.nrrd"
    # NOTE(review): relative path (no leading '/'); confirm this is intended.
    validation_dir = 'home/AtriaSeg_2018_testing'
    image_ids = []
    encode_cavity = []

    for patient_name in sorted(os.listdir(validation_dir)):
        print('encode ', patient_name)
        patient_dir = os.path.join(validation_dir, patient_name)
        if not os.path.isdir(patient_dir):
            continue  # only sub-directories hold patient predictions

        mask, _ = load_nrrd(join(patient_dir, mask_format_name))
        mask[mask > 0] = 1  # binarise before encoding

        n_slices = mask.shape[0]
        image_ids.extend(
            patient_name + "_Slice_" + str(k) for k in range(n_slices))
        encode_cavity.extend(
            run_length_encoding(mask[k, :, :]) for k in range(n_slices))

    # NOTE(review): csv_output is built but not written to disk in this excerpt.
    csv_output = pd.DataFrame(
        data={"ImageId": image_ids, 'EncodeCavity': encode_cavity},
        columns=['ImageId', 'EncodeCavity'])
コード例 #6
0
    # Excerpt of a per-patient prediction loop; the enclosing function and the
    # rest of the loop body lie outside this excerpt.
    image_name = 'lgemri.nrrd'  # input volume file inside each patient folder
    # NOTE(review): hard-coded, user-specific absolute path; consider making it
    # a parameter / command-line argument.
    test_dir = '/vol/medic01/users/cc215/data/AtriaSeg_2018_training/AtriaSeg_2018_testing'

    t=0.  # presumably accumulates elapsed time after this excerpt — TODO confirm
    count=0  # number of patient directories processed
    for patient_name in sorted(os.listdir(test_dir)):
        start_time=time()
        print ('predict {}'.format(patient_name))
        patient_img_dir=join(test_dir,patient_name)
        # Skip stray files; only sub-directories are treated as patients.
        if not os.path.isdir(patient_img_dir):
            continue
        count+=1
        patient_img_path=join(patient_img_dir,image_name)

        # dtype=sitk.sitkUInt8 requests an unsigned 8-bit pixel type from load_nrrd.
        data,sitk_image=load_nrrd(patient_img_path,dtype=sitk.sitkUInt8)

        ## load model
        # NOTE(review): every model is rebuilt and its checkpoint re-read with
        # torch.load for each patient; hoisting this out of the patient loop
        # would avoid the repeated work.
        model_list=[]
        n_classes=2  # unused in the visible part; MT_Net below gets n_classes=2 directly
        for k,v in model_dict.items():
            # Keys select the architecture; only 'multi_task' models are supported.
            if 'multi_task' in k:
                model = MT_Net(n_channels=1, n_classes=2, n_labels=2, if_dropout=False, spp_grid=[8, 4, 1], upsample_type='bilinear')
            else:
                raise NotImplementedError

            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            cache=torch.load(v)

            # The checkpoint dict holds the weights under 'model_state'.
            model.load_state_dict(cache['model_state'])
            if gpu:
                # Wrap the model for execution on GPU 0.
                model = torch.nn.DataParallel(model, device_ids=[0])
コード例 #7
0
        # Tail of a try block whose start lies outside this excerpt: move the
        # network to the GPU.
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
    # NOTE(review): bare `except:` turns any failure (e.g. no CUDA device) into a
    # message-less NotImplementedError and hides the cause; consider
    # `except Exception as err: raise ... from err`.
    except:
        raise NotImplementedError
    print("Model loaded !")

    ##predict
    # Put the network in evaluation mode before running inference.
    net.eval()
    t = 0  # presumably accumulates prediction time after this excerpt — TODO confirm
    for i, fn in enumerate(patient_path_list):

        start_time = time()
        print("\nPredicting image {} ...".format(fn))
        path = os.path.join(fn, PREDICT_IMAGE_NAME)
        # Patients without the input image are skipped.
        if not os.path.exists(path): continue
        # NOTE(review): re-joins the same path; `path` could be reused here.
        raw_data, sitk_image = load_nrrd(os.path.join(fn, PREDICT_IMAGE_NAME))
        print(raw_data.shape)
        ## Re-orient the volume according to the model type named in args.model:
        # coronal -> axes (1, 0, 2), sagittal -> axes (2, 0, 1), otherwise unchanged.
        if 'coronal' in args.model:
            raw_data = np.transpose(raw_data, (1, 0, 2))
        elif 'sagittal' in args.model:
            raw_data = np.transpose(raw_data, (2, 0, 1))

        # N_CLASEES (sic) is defined outside this excerpt.
        result, original_result, result_prob_map = predict_img(
            sequence=sequence,
            if_mip=if_mip,
            if_gamma=args.gamma,
            if_clahe=if_clahe,
            n_classes=N_CLASEES,
            full_img=raw_data,
            batch_size=args.batch_size,