Example #1
def score_data(input_folder,
               output_folder,
               model_path,
               num_classes=3,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False):

    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = num_classes + 1

    net = UNet(in_dim=1, out_dim=num_classes + 1).cuda()
    ckpt_path = os.path.join(model_path, 'best_model.pth.tar')
    net.load_state_dict(_pickle.load(open(ckpt_path, 'rb')))

    evaluate_test_set = not gt_exists

    total_time = 0
    total_volumes = 0

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test == 'test':

                infos = {}
                for line in open(os.path.join(folder_path, 'Info.cfg')):
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

                patient_id = folder.lstrip('patient')
                ED_frame = int(infos['ED'])
                ES_frame = int(infos['ES'])

                for file in glob.glob(
                        os.path.join(folder_path,
                                     'patient???_frame??.nii.gz')):

                    logging.info(
                        ' ----- Doing image: -------------------------')
                    logging.info('Doing: %s' % file)
                    logging.info(
                        ' --------------------------------------------')

                    file_base = file.split('.nii.gz')[0]

                    frame = int(file_base.split('frame')[-1])
                    img_dat = utils.load_nii(file)
                    img = img_dat[0].copy()
                    img = image_utils.normalise_image(img)

                    if gt_exists:
                        file_mask = file_base + '_gt.nii.gz'
                        mask_dat = utils.load_nii(file_mask)
                        mask = mask_dat[0]

                    start_time = time.time()

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape

                        x_s = (x - nx) // 2
                        y_s = (y - ny) // 2
                        x_c = (nx - x) // 2
                        y_c = (ny - y) // 2

                        # Crop section of image for prediction
                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx, y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c + x, :] = slice_rescaled[:, y_s:y_s + ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c + y] = slice_rescaled[x_s:x_s + nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_rescaled[:, :]

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        network_input = np.transpose(network_input,
                                                     [0, 3, 1, 2])
                        network_input = torch.cuda.FloatTensor(network_input)
                        with torch.no_grad():
                            net.eval()
                            logits_out = net(network_input)
                            softmax_out = F.softmax(logits_out, dim=1)
                            # mask_out = torch.argmax(logits_out, dim=1)
                            softmax_out = softmax_out.data.cpu().numpy()
                            softmax_out = np.transpose(softmax_out,
                                                       [0, 2, 3, 1])
                        # prediction_cropped = np.squeeze(softmax_out[0,...])
                        prediction_cropped = np.squeeze(softmax_out)

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = np.zeros((x, y, num_channels))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            slice_predictions[x_s:x_s + nx, y_s:y_s + ny, :] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                slice_predictions[:, y_s:y_s + ny, :] = prediction_cropped[x_c:x_c + x, :, :]
                            elif x > nx and y <= ny:
                                slice_predictions[x_s:x_s + nx, :, :] = prediction_cropped[:, y_c:y_c + y, :]
                            else:
                                slice_predictions[:, :, :] = prediction_cropped[x_c:x_c + x, y_c:y_c + y, :]

                        # RESCALE THE SOFTMAX PREDICTIONS BACK TO THE ORIGINAL RESOLUTION
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:
                            # Rescaling by the inverse scale factor can occasionally give a
                            # volume of the wrong size, which is why, when gt_exists, the
                            # branch above resizes to the gt mask shape instead.
                            prediction = transform.rescale(
                                slice_predictions, (1.0 / scale_vector[0],
                                                    1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        if num_classes == 1:
                            prediction[prediction == 1] = 3
                        elif num_classes == 2:
                            prediction[prediction == 2] = 3
                            prediction[prediction == 1] = 2
                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                    # This is the same for 2D and 3D again
                    if do_postprocessing:
                        assert num_classes == 1
                        from skimage.measure import regionprops
                        lv_obj = (mask_dat[0] == 3).astype(np.uint8)
                        prop = regionprops(lv_obj)
                        assert len(prop) == 1
                        prop = prop[0]
                        centroid = prop.centroid
                        centroid = (int(centroid[0]), int(centroid[1]),
                                    int(centroid[2]))
                        prediction_arr = image_utils.keep_largest_connected_components(
                            prediction_arr, centroid)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1

                    logging.info('Evaluation of volume took %f secs.' %
                                 elapsed_time)

                    if frame == ED_frame:
                        frame_suffix = '_ED'
                    elif frame == ES_frame:
                        frame_suffix = '_ES'
                    else:
                        raise ValueError(
                            "Frame doesn't correspond to ED or ES. frame = %d, ED = %d, ES = %d"
                            % (frame, ED_frame, ES_frame))

                    # Save predicted mask
                    out_file_name = os.path.join(
                        output_folder, 'prediction',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    if gt_exists:
                        out_affine = mask_dat[1]
                        out_header = mask_dat[2]
                    else:
                        out_affine = img_dat[1]
                        out_header = img_dat[2]

                    logging.info('saving to: %s' % out_file_name)
                    utils.save_nii(out_file_name, prediction_arr, out_affine,
                                   out_header)

                    # Save image data to the same folder for convenience
                    image_file_name = os.path.join(
                        output_folder, 'image',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % image_file_name)
                    utils.save_nii(image_file_name, img_dat[0], out_affine,
                                   out_header)

                    if gt_exists:

                        # Save GT image
                        gt_file_name = os.path.join(
                            output_folder, 'ground_truth',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % gt_file_name)
                        utils.save_nii(gt_file_name, mask, out_affine,
                                       out_header)

                        # Save difference mask between predictions and ground truth
                        difference_mask = np.where(
                            np.abs(prediction_arr - mask) > 0, [1], [0])
                        difference_mask = np.asarray(difference_mask,
                                                     dtype=np.uint8)
                        diff_file_name = os.path.join(
                            output_folder, 'difference',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % diff_file_name)
                        utils.save_nii(diff_file_name, difference_mask,
                                       out_affine, out_header)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    return None
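A minimal, hypothetical driver for score_data above. The paths, the dataset name and the pre-created output subfolders (prediction/, image/, ground_truth/, difference/) are assumptions for illustration; only the function signature comes from the example itself.

import logging
import os

# Placeholder paths; the real data and checkpoint locations depend on the project setup.
INPUT_FOLDER = 'data/ACDC/training'    # contains patientXXX/ folders with Info.cfg
OUTPUT_FOLDER = 'output/run1'
MODEL_PATH = 'checkpoints/run1'        # expected to contain best_model.pth.tar

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # score_data writes into these subfolders, so make sure they exist first.
    for sub in ('prediction', 'image', 'ground_truth', 'difference'):
        os.makedirs(os.path.join(OUTPUT_FOLDER, sub), exist_ok=True)
    score_data(INPUT_FOLDER, OUTPUT_FOLDER, MODEL_PATH,
               num_classes=3,
               do_postprocessing=False,  # postprocessing in this example asserts num_classes == 1
               gt_exists=True,
               evaluate_all=True)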
Example #2
def score_data(input_folder,
               output_folder,
               model_path,
               args,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               random_center_ratio=None):
    num_classes = args.num_cls
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = num_classes + 1

    net = UNet(in_dim=1, out_dim=4).cuda()
    ckpt_path = os.path.join(model_path, 'best_model.pth.tar')
    net.load_state_dict(_pickle.load(open(ckpt_path, 'rb'))[0])
    if args.unet_ckpt:
        pretrained_unet = UNet(in_dim=1, out_dim=4).cuda()
        pretrained_unet.load_state_dict(
            _pickle.load(open(args.unet_ckpt, 'rb')))

    snake = SnakePytorch(args.delta, 1, args.num_lines, args.radius)

    evaluate_test_set = not gt_exists

    total_time = 0
    total_volumes = 0

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test == 'test':

                infos = {}
                for line in open(os.path.join(folder_path, 'Info.cfg')):
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

                patient_id = folder.lstrip('patient')
                ED_frame = int(infos['ED'])
                ES_frame = int(infos['ES'])

                for file in glob.glob(
                        os.path.join(folder_path,
                                     'patient???_frame??.nii.gz')):

                    logging.info(
                        ' ----- Doing image: -------------------------')
                    logging.info('Doing: %s' % file)
                    logging.info(
                        ' --------------------------------------------')

                    file_base = file.split('.nii.gz')[0]

                    frame = int(file_base.split('frame')[-1])
                    img_dat = utils.load_nii(file)
                    img = img_dat[0].copy()
                    img = image_utils.normalise_image(img)

                    if gt_exists:
                        file_mask = file_base + '_gt.nii.gz'
                        mask_dat = utils.load_nii(file_mask)
                        mask = mask_dat[0]

                    start_time = time.time()

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape
                        slice_cropped, x_s, y_s, x_c, y_c = get_slice(
                            slice_rescaled, nx, ny)
                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        network_input = np.transpose(network_input,
                                                     [0, 3, 1, 2])
                        network_input = torch.cuda.FloatTensor(network_input)
                        with torch.no_grad():
                            net.eval()
                            logit = net(network_input)

                        # get the center
                        if args.unet_ckpt != '':
                            unet_mask = torch.argmax(
                                pretrained_unet(network_input),
                                dim=1).data.cpu().numpy()[0]
                        else:
                            assert gt_exists
                            mask_copy = mask[:, :, zz].copy()
                            unet_mask = get_slice(mask_copy, nx, ny)[0]
                        unet_mask = image_utils.keep_largest_connected_components(
                            unet_mask)
                        from data_iterator import get_center_of_mass
                        if num_classes == 2:
                            lv_center = get_center_of_mass(unet_mask, [3])
                            mo_center = get_center_of_mass(unet_mask, [2])
                        else:
                            lv_center = get_center_of_mass(unet_mask, [3])
                            mo_center = np.asarray([[np.nan, np.nan]])
                        lv_center = np.asarray(lv_center)
                        mo_center = np.asarray(mo_center)

                        lv_mask = np.zeros((nx, ny))
                        if not np.isnan(lv_center[0, 0]):
                            if random_center_ratio:
                                dt, _ = get_distance_transform(
                                    unet_mask == 3, None)
                                max_radius = dt[0,
                                                int(lv_center[0][0]),
                                                int(lv_center[0][1])]
                                radius = int(max_radius * random_center_ratio)
                                c_j, _ = get_random_jitter(radius, 0)
                            else:
                                c_j = None

                            lv_logit, _, _ = get_star_pattern_values(
                                logit[:, 3, ...],
                                None,
                                lv_center,
                                args.num_lines,
                                args.radius + 1,
                                center_jitters=c_j)
                            # compute the gradient along each ray
                            lv_gs = lv_logit[:, :, 1:] - lv_logit[:, :, :-1]
                            # run DP algo
                            # can only put batch with fixed shape into the snake algorithm
                            lv_ind = snake(lv_gs).data.cpu().numpy()
                            lv_ind = np.expand_dims(
                                smooth_ind(lv_ind.squeeze(-1),
                                           args.smoothing_window), -1)
                            lv_mask = star_pattern_ind_to_mask(
                                lv_ind, lv_center, nx, ny, args.num_lines,
                                args.radius)

                        if num_classes == 1:
                            pred_mask = lv_mask * 3
                        else:
                            mo_mask = np.zeros((nx, ny))
                            # check the single centre's first coordinate, mirroring the LV check above
                            if not np.isnan(mo_center[0, 0]):
                                c_j = None
                                mo_logit, _, _ = get_star_pattern_values(
                                    logit[:, 2, ...],
                                    None,
                                    lv_center,
                                    args.num_lines,
                                    args.radius + 1,
                                    center_jitters=c_j)
                                # compute the gradient along each ray
                                mo_gs = mo_logit[:, :, 1:] - mo_logit[:, :, :-1]
                                mo_ind = snake(mo_gs).data.cpu().numpy()
                                mo_ind = mo_ind[:len(mo_gs), ...]
                                mo_ind = np.expand_dims(
                                    smooth_ind(mo_ind.squeeze(-1),
                                               args.smoothing_window), -1)
                                mo_mask = star_pattern_ind_to_mask(
                                    mo_ind, lv_center, nx, ny, args.num_lines,
                                    args.radius)
                            # 3 is the LV class, 2 is the myocardium class
                            pred_mask = lv_mask * 3 + (1 - lv_mask) * mo_mask * 2

                        prediction_cropped = pred_mask.squeeze()
                        # ASSEMBLE BACK THE SLICES
                        prediction = np.zeros((x, y))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            prediction[x_s:x_s + nx, y_s:y_s + ny] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                prediction[:, y_s:y_s + ny] = prediction_cropped[x_c:x_c + x, :]
                            elif x > nx and y <= ny:
                                prediction[x_s:x_s + nx, :] = prediction_cropped[:, y_c:y_c + y]
                            else:
                                prediction[:, :] = prediction_cropped[x_c:x_c + x, y_c:y_c + y]

                        # RESCALE THE PREDICTED MASK BACK TO THE ORIGINAL RESOLUTION
                        if gt_exists:
                            prediction = transform.resize(
                                prediction, (mask.shape[0], mask.shape[1]),
                                order=0,
                                preserve_range=True,
                                mode='constant')
                        else:
                            # Rescaling by the inverse scale factor can occasionally give a
                            # volume of the wrong size, which is why, when gt_exists, the
                            # branch above resizes to the gt mask shape instead.
                            prediction = transform.rescale(
                                prediction,
                                (1.0 / scale_vector[0], 1.0 / scale_vector[1]),
                                order=0,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        # prediction = np.uint8(np.argmax(prediction, axis=-1))
                        prediction = np.uint8(prediction)
                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                    # This is the same for 2D and 3D again
                    if do_postprocessing:
                        prediction_arr = image_utils.keep_largest_connected_components(
                            prediction_arr)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1

                    logging.info('Evaluation of volume took %f secs.' %
                                 elapsed_time)

                    if frame == ED_frame:
                        frame_suffix = '_ED'
                    elif frame == ES_frame:
                        frame_suffix = '_ES'
                    else:
                        raise ValueError(
                            "Frame doesn't correspond to ED or ES. frame = %d, ED = %d, ES = %d"
                            % (frame, ED_frame, ES_frame))

                    # Save predicted mask
                    out_file_name = os.path.join(
                        output_folder, 'prediction',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    if gt_exists:
                        out_affine = mask_dat[1]
                        out_header = mask_dat[2]
                    else:
                        out_affine = img_dat[1]
                        out_header = img_dat[2]

                    logging.info('saving to: %s' % out_file_name)
                    utils.save_nii(out_file_name, prediction_arr, out_affine,
                                   out_header)

                    # Save image data to the same folder for convenience
                    image_file_name = os.path.join(
                        output_folder, 'image',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % image_file_name)
                    utils.save_nii(image_file_name, img_dat[0], out_affine,
                                   out_header)

                    if gt_exists:
                        # Save GT image
                        gt_file_name = os.path.join(
                            output_folder, 'ground_truth',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % gt_file_name)
                        utils.save_nii(gt_file_name, mask, out_affine,
                                       out_header)

                        # Save difference mask between predictions and ground truth
                        difference_mask = np.where(
                            np.abs(prediction_arr - mask) > 0, [1], [0])
                        difference_mask = np.asarray(difference_mask,
                                                     dtype=np.uint8)
                        diff_file_name = os.path.join(
                            output_folder, 'difference',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % diff_file_name)
                        utils.save_nii(diff_file_name, difference_mask,
                                       out_affine, out_header)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    return None
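Another hypothetical invocation, this time for the star-pattern/snake variant. The fields of the args namespace (num_cls, unet_ckpt, delta, num_lines, radius, smoothing_window) are inferred from the attribute accesses in the function body; the concrete values below are illustrative only.

from argparse import Namespace

# Illustrative settings only; tune them to the real experiment configuration.
args = Namespace(num_cls=2,           # 2 -> LV + myocardium contours, 1 -> LV only
                 unet_ckpt='',        # empty string -> derive the centre from the GT mask
                 delta=1,             # smoothness constraint for the snake DP
                 num_lines=64,        # number of rays in the star pattern
                 radius=65,           # ray length in pixels
                 smoothing_window=5)  # window passed to smooth_ind

if __name__ == '__main__':
    score_data('data/ACDC/training', 'output/snake_run', 'checkpoints/snake_run',
               args,
               do_postprocessing=True,
               gt_exists=True,
               evaluate_all=True)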
Example #3
def prepare_data(input_folder,
                 output_file,
                 mode,
                 size,
                 target_resolution,
                 split_test_train=True):
    '''
    Main function that converts the raw challenge data into an HDF5 dataset.
    '''

    assert (mode in ['2D', '3D']), 'Unknown mode: %s' % mode
    if mode == '2D' and not len(size) == 2:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '3D' and not len(size) == 3:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '2D' and not len(target_resolution) == 2:
        raise AssertionError(
            'Inadequate number of target resolution parameters')
    if mode == '3D' and not len(target_resolution) == 3:
        raise AssertionError(
            'Inadequate number of target resolution parameters')

    hdf5_file = h5py.File(output_file, "w")

    diag_list = {'test': [], 'train': []}
    height_list = {'test': [], 'train': []}
    weight_list = {'test': [], 'train': []}
    patient_id_list = {'test': [], 'train': []}
    cardiac_phase_list = {'test': [], 'train': []}

    file_list = {'test': [], 'train': []}
    num_slices = {'test': 0, 'train': 0}

    logging.info('Counting files and parsing meta data...')

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            if split_test_train:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'
            else:
                train_test = 'train'

            infos = {}
            for line in open(os.path.join(folder_path, 'Info.cfg')):
                label, value = line.split(':')
                infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')

            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                file_list[train_test].append(file)

                # diag_list[train_test].append(diagnosis_to_int(infos['Group']))
                diag_list[train_test].append(diagnosis_dict[infos['Group']])
                weight_list[train_test].append(infos['Weight'])
                height_list[train_test].append(infos['Height'])

                patient_id_list[train_test].append(patient_id)

                systole_frame = int(infos['ES'])
                diastole_frame = int(infos['ED'])

                file_base = file.split('.')[0]
                frame = int(file_base.split('frame')[-1])
                if frame == systole_frame:
                    cardiac_phase_list[train_test].append(1)  # 1 == systole
                elif frame == diastole_frame:
                    cardiac_phase_list[train_test].append(2)  # 2 == diastole
                else:
                    cardiac_phase_list[train_test].append(
                        0)  # 0 means other phase

                nifty_img = nib.load(file)
                num_slices[train_test] += nifty_img.shape[2]

    # Write the small datasets
    for tt in ['test', 'train']:
        hdf5_file.create_dataset('diagnosis_%s' % tt,
                                 data=np.asarray(diag_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('weight_%s' % tt,
                                 data=np.asarray(weight_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('height_%s' % tt,
                                 data=np.asarray(height_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('patient_id_%s' % tt,
                                 data=np.asarray(patient_id_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('cardiac_phase_%s' % tt,
                                 data=np.asarray(cardiac_phase_list[tt],
                                                 dtype=np.uint8))

    if mode == '3D':
        nx, ny, nz_max = size
        n_train = len(file_list['train'])
        n_test = len(file_list['test'])

    elif mode == '2D':
        nx, ny = size
        n_test = num_slices['test']
        n_train = num_slices['train']

    else:
        raise AssertionError('Wrong mode setting. This should never happen.')

    # Create datasets for images and masks
    data = {}
    for tt, num_points in zip(['test', 'train'], [n_test, n_train]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset(
                "images_%s" % tt, [num_points] + list(size), dtype=np.float32)
            data['masks_%s' % tt] = hdf5_file.create_dataset(
                "masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)

    mask_list = {'test': [], 'train': []}
    img_list = {'test': [], 'train': []}

    logging.info('Parsing image files')

    train_test_range = ['test', 'train'] if split_test_train else ['train']
    for train_test in train_test_range:

        write_buffer = 0
        counter_from = 0

        for file in file_list[train_test]:

            logging.info(
                '-----------------------------------------------------------')
            logging.info('Doing: %s' % file)

            file_base = file.split('.nii.gz')[0]
            file_mask = file_base + '_gt.nii.gz'

            img_dat = utils.load_nii(file)
            mask_dat = utils.load_nii(file_mask)

            img = img_dat[0].copy()
            mask = mask_dat[0].copy()

            img = image_utils.normalise_image(img)

            pixel_size = (img_dat[2].structarr['pixdim'][1],
                          img_dat[2].structarr['pixdim'][2],
                          img_dat[2].structarr['pixdim'][3])

            logging.info('Pixel size:')
            logging.info(pixel_size)

            ### PROCESSING LOOP FOR 3D DATA ################################
            if mode == '3D':

                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1],
                    pixel_size[2] / target_resolution[2]
                ]

                img_scaled = transform.rescale(img,
                                               scale_vector,
                                               order=1,
                                               preserve_range=True,
                                               multichannel=False,
                                               mode='constant')
                mask_scaled = transform.rescale(mask,
                                                scale_vector,
                                                order=0,
                                                preserve_range=True,
                                                multichannel=False,
                                                mode='constant')

                slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)
                mask_vol = np.zeros((nx, ny, nz_max), dtype=np.uint8)

                nz_curr = img_scaled.shape[2]
                stack_from = (nz_max - nz_curr) // 2

                if stack_from < 0:
                    raise AssertionError(
                        'nz_max is too small for the chosen through-plane resolution. '
                        'Consider changing the size or the target resolution in the '
                        'through-plane.')

                for zz in range(nz_curr):
                    slice_rescaled = img_scaled[:, :, zz]
                    mask_rescaled = mask_scaled[:, :, zz]

                    slice_cropped = crop_or_pad_slice_to_size(
                        slice_rescaled, nx, ny)
                    mask_cropped = crop_or_pad_slice_to_size(
                        mask_rescaled, nx, ny)

                    slice_vol[:, :, stack_from] = slice_cropped
                    mask_vol[:, :, stack_from] = mask_cropped

                    stack_from += 1

                img_list[train_test].append(slice_vol)
                mask_list[train_test].append(mask_vol)

                write_buffer += 1

                if write_buffer >= MAX_WRITE_BUFFER:
                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, img_list, mask_list,
                                         counter_from, counter_to)
                    _release_tmp_memory(img_list, mask_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0

            ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
            elif mode == '2D':

                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1]
                ]

                for zz in range(img.shape[2]):

                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       mode='constant')

                    slice_mask = np.squeeze(mask[:, :, zz])
                    mask_rescaled = transform.rescale(slice_mask,
                                                      scale_vector,
                                                      order=0,
                                                      preserve_range=True,
                                                      multichannel=False,
                                                      mode='constant')

                    slice_cropped = crop_or_pad_slice_to_size(
                        slice_rescaled, nx, ny)
                    mask_cropped = crop_or_pad_slice_to_size(
                        mask_rescaled, nx, ny)

                    img_list[train_test].append(slice_cropped)
                    mask_list[train_test].append(mask_cropped)

                    write_buffer += 1

                    # Writing needs to happen inside the loop over the slices
                    if write_buffer >= MAX_WRITE_BUFFER:
                        counter_to = counter_from + write_buffer
                        _write_range_to_hdf5(data, train_test, img_list,
                                             mask_list, counter_from,
                                             counter_to)
                        _release_tmp_memory(img_list, mask_list, train_test)

                        # reset stuff for next iteration
                        counter_from = counter_to
                        write_buffer = 0

        # after file loop: Write the remaining data

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, img_list, mask_list,
                             counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, train_test)

    # After test train loop:
    hdf5_file.close()
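A hypothetical call that builds a 2D HDF5 dataset with this function. The size and target resolution values are illustrative, and the module-level helpers it relies on (crop_or_pad_slice_to_size, _write_range_to_hdf5, _release_tmp_memory, diagnosis_dict, MAX_WRITE_BUFFER) must already be defined alongside it.

if __name__ == '__main__':
    # Illustrative parameters: 212x212 in-plane crops at an assumed 1.37 mm target resolution.
    prepare_data(input_folder='data/ACDC/training',
                 output_file='preproc/data_2D_size_212_212_res_1.37_1.37.hdf5',
                 mode='2D',
                 size=(212, 212),
                 target_resolution=(1.37, 1.37),
                 split_test_train=True)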
Example #4
def compute_metrics_on_directories_raw(dir_gt, dir_pred):
    """
    Calculates a number of measures from the predicted and ground truth segmentations:
    - Dice
    - Hausdorff distance
    - Average surface distance
    - Predicted volume
    - Volume error w.r.t. ground truth

    :param dir_gt: Directory of the ground truth segmentation maps.
    :param dir_pred: Directory of the predicted segmentation maps.
    :return: Pandas dataframe with one row per prediction and structure, containing all measures.
    """

    filenames_gt = sorted(glob(os.path.join(dir_gt, '*')), key=natural_order)
    filenames_pred = sorted(glob(os.path.join(dir_pred, '*')), key=natural_order)

    cardiac_phase = []
    file_names = []
    structure_names = []

    # 5 measures per structure:
    dices_list = []
    hausdorff_list = []
    assd_list = []
    vol_list = []
    vol_err_list = []

    structures_dict = {1: 'RV', 2: 'Myo', 3: 'LV'}

    for p_gt, p_pred in zip(filenames_gt, filenames_pred):
        if os.path.basename(p_gt) != os.path.basename(p_pred):
            raise ValueError("The two files don't have the same name"
                             " {}, {}.".format(os.path.basename(p_gt),
                                               os.path.basename(p_pred)))

        # load ground truth and prediction
        gt, _, header = utils.load_nii(p_gt)
        pred, _, _ = utils.load_nii(p_pred)
        zooms = header.get_zooms()

        # calculate measures for each structure
        for struc in [3, 2, 1]:

            gt_binary = (gt == struc) * 1
            pred_binary = (pred == struc) * 1

            volpred = pred_binary.sum() * np.prod(zooms) / 1000.
            volgt = gt_binary.sum() * np.prod(zooms) / 1000.

            vol_list.append(volpred)
            vol_err_list.append(volpred - volgt)
            if struc == 3:
                hausdorff_list.append(hd(gt_binary, pred_binary, voxelspacing=zooms, connectivity=1))
                assd_list.append(assd(pred_binary, gt_binary, voxelspacing=zooms, connectivity=1))
            else:
                hausdorff_list.append(0.0)
                assd_list.append(0.0)
            dices_list.append(dc(gt_binary, pred_binary))

            cardiac_phase.append(os.path.basename(p_gt).split('.nii.gz')[0].split('_')[-1])
            file_names.append(os.path.basename(p_pred))
            structure_names.append(structures_dict[struc])


    df = pd.DataFrame({'dice': dices_list, 'hd': hausdorff_list, 'assd': assd_list,
                       'vol': vol_list, 'vol_err': vol_err_list,
                       'phase': cardiac_phase, 'struc': structure_names,
                       'filename': file_names})

    return df
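A short, hypothetical follow-up showing how the returned dataframe might be summarised; it relies only on the columns created above, and the directory paths are placeholders.

if __name__ == '__main__':
    # Both directories must contain identically named NIfTI files (e.g. patient001_ED.nii.gz).
    df = compute_metrics_on_directories_raw('output/run1/ground_truth',
                                            'output/run1/prediction')
    # Mean of each measure per structure and cardiac phase.
    print(df.groupby(['struc', 'phase'])[['dice', 'hd', 'assd', 'vol', 'vol_err']].mean())
    df.to_csv('output/run1/metrics_raw.csv', index=False)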