Example #1
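# NOTE: these snippets omit their imports; a minimal assumed set (the
# project-local helpers are taken on faith from the calls below):
#
#   import glob
#   import numpy as np
#   import matplotlib.pyplot as plt
#   import matplotlib.cm as cm
#   import keras_metrics as km
#   import elasticdeform
#   from keras.models import load_model
#
# opening_files, apply_detection_model, apply_identification_model,
# weighted_categorical_crossentropy, dice_coef_label, ignore_background_loss,
# vertebrae_classification_rate, pre_compute_disks, densely_label,
# LABELS_NO_L6 and VERTEBRAE_SIZES are project-local names used throughout.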
def single_identification(scan_path,
                          detection_model_path,
                          identification_model_path,
                          plot_path,
                          spacing=(1.0, 1.0, 1.0)):
    scan_path_without_ext = scan_path[:-len(".nii.gz")]
    centroid_path = scan_path_without_ext + ".lml"

    labels, centroids = opening_files.extract_centroid_info_from_lml(
        centroid_path)
    centroid_indexes = centroids / np.array(spacing)

    cut = np.round(np.mean(centroid_indexes[:, 0])).astype(int)

    weights = np.array([0.1, 0.9])
    detection_model_objects = {
        'loss': weighted_categorical_crossentropy(weights),
        'binary_recall': km.binary_recall(),
        'dice_coef': dice_coef_label(label=1)
    }

    detection_model = load_model(detection_model_path,
                                 custom_objects=detection_model_objects)

    identification_model_objects = {
        'ignore_background_loss': ignore_background_loss,
        'vertebrae_classification_rate': vertebrae_classification_rate
    }

    identification_model = load_model(
        identification_model_path, custom_objects=identification_model_objects)

    volume = opening_files.read_nii(scan_path, spacing=spacing)

    detections = apply_detection_model(volume, detection_model,
                                       np.array([64, 64, 80]),
                                       np.array([32, 32, 40]))
    identification = apply_identification_model(volume, cut - 1, cut + 1,
                                                identification_model)

    volume_slice = volume[cut, :, :]
    detection_slice = detections[cut, :, :]
    identification_slice = identification[cut, :, :]

    identification_slice *= detection_slice

    masked_data = np.ma.masked_where(identification_slice == 0,
                                     identification_slice)

    fig, ax = plt.subplots(1)

    ax.imshow(volume_slice.T, cmap='gray', origin='lower')
    ax.imshow(masked_data.T,
              cmap=cm.jet,
              vmin=1,
              vmax=27,
              alpha=0.4,
              origin='lower')
    fig.savefig(plot_path + '/single_identification.png')
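A minimal invocation sketch for the function above; every path is a hypothetical placeholder, not a file from the original project:

single_identification(
    scan_path="datasets/patient0001/scan.nii.gz",              # placeholder
    detection_model_path="saved_models/detection.h5",          # placeholder
    identification_model_path="saved_models/identification.h5",
    plot_path="plots",
    spacing=(1.0, 1.0, 1.0))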
Example #2
def test_scan(scan_path,
              detection_model,
              detection_X_shape,
              detection_y_shape,
              identification_model,
              spacing=(2.0, 2.0, 2.0)):

    volume = opening_files.read_nii(scan_path, spacing)

    # first stage is to put the volume through the detection model to find where vertebrae are
    print("apply detection")
    detections = apply_detection_model(volume, detection_model,
                                       detection_X_shape, detection_y_shape)
    print("finished detection")

    # collect the coordinates of every detected voxel; the largest-island
    # cropping below is disabled in this version
    # _, largest_island_np = sampling_helper_functions.crop_labelling(detections)
    largest_island_np = np.transpose(np.nonzero(detections))
    i_min = np.min(largest_island_np[:, 0])
    i_max = np.max(largest_island_np[:, 0])

    # second stage is to pass slices of this to the identification network
    print("apply identification")
    identifications = apply_identification_model(volume, i_min, i_max,
                                                 identification_model)
    print("finished identification")

    # mask the identification map with the detection map
    identifications *= detections
    print("finished multiplying")

    # aggregate the predictions
    print("start aggregating")
    identifications = np.round(identifications).astype(int)
    histogram = {}
    for key in range(1, len(LABELS_NO_L6)):
        histogram[key] = np.argwhere(identifications == key)
    print("finish aggregating")

    print("start averages")
    # find averages
    labels = []
    centroid_estimates = []
    for key in sorted(histogram.keys()):
        if 0 <= key < len(LABELS_NO_L6):
            arr = histogram[key]
            # print(LABELS_NO_L6[key], arr.shape[0])
            if arr.shape[0] > max(VERTEBRAE_SIZES[LABELS_NO_L6[key]]**3 * 0.4,
                                  3000):
                print(LABELS_NO_L6[key], arr.shape[0])
                centroid_estimate = np.median(arr, axis=0)
                # ms = MeanShift(bin_seeding=True, min_bin_freq=300)
                # ms.fit(arr)
                # centroid_estimate = ms.cluster_centers_[0]
                centroid_estimate = np.around(centroid_estimate, decimals=2)
                labels.append(LABELS_NO_L6[key])
                centroid_estimates.append(list(centroid_estimate))
    print("finish averages")

    return labels, centroid_estimates, detections, identifications
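A hedged driver sketch for test_scan, reusing the custom_objects dictionaries shown in Example #1 to load the models; all paths are placeholders:

detection_model = load_model("saved_models/detection.h5",      # placeholder
                             custom_objects=detection_model_objects)
identification_model = load_model("saved_models/identification.h5",
                                  custom_objects=identification_model_objects)
labels, centroid_estimates, detections, identifications = test_scan(
    scan_path="datasets/patient0001/scan.nii.gz",              # placeholder
    detection_model=detection_model,
    detection_X_shape=np.array([64, 64, 80]),
    detection_y_shape=np.array([32, 32, 40]),
    identification_model=identification_model,
    spacing=(2.0, 2.0, 2.0))
for label, centroid in zip(labels, centroid_estimates):
    print(label, centroid)  # centroid estimates are in voxel indices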
Example #3
def complete_identification_picture(scans_dir,
                                    detection_model_path,
                                    identification_model_path,
                                    plot_path,
                                    start,
                                    end,
                                    spacing=(2.0, 2.0, 2.0)):
    scan_paths = glob.glob(scans_dir + "/**/*.nii.gz",
                           recursive=True)[start:end]
    no_of_scan_paths = len(scan_paths)

    weights = np.array([0.1, 0.9])
    detection_model_objects = {
        'loss': weighted_categorical_crossentropy(weights),
        'binary_recall': km.binary_recall(),
        'dice_coef': dice_coef_label(label=1)
    }

    detection_model = load_model(detection_model_path,
                                 custom_objects=detection_model_objects)

    identification_model_objects = {
        'ignore_background_loss': ignore_background_loss,
        'vertebrae_classification_rate': vertebrae_classification_rate
    }

    identification_model = load_model(
        identification_model_path, custom_objects=identification_model_objects)

    fig, axes = plt.subplots(nrows=1,
                             ncols=no_of_scan_paths,
                             figsize=(15, 6),
                             dpi=300)

    for col, scan_path in enumerate(scan_paths):
        print(col + 1, scan_path)
        scan_path_without_ext = scan_path[:-len(".nii.gz")]
        centroid_path = scan_path_without_ext + ".lml"

        labels, centroids = opening_files.extract_centroid_info_from_lml(
            centroid_path)
        centroid_indexes = centroids / np.array(spacing)

        cut = np.round(np.mean(centroid_indexes[:, 0])).astype(int)

        scan_name = (scan_path.rsplit('/', 1)[-1])[:-len(".nii.gz")]
        axes[col].set_title(scan_name, fontsize=10, pad=10)

        # model names feed the (currently disabled) ylabel below
        detection_model_name = (detection_model_path.rsplit(
            '/', 1)[-1])[:-len(".h5")]
        identification_model_name = (identification_model_path.rsplit(
            '/', 1)[-1])[:-len(".h5")]
        name = detection_model_name + "\n" + identification_model_name
        # axes[0].set_ylabel(name, rotation=0, labelpad=50, fontsize=10)

        pred_labels, pred_centroid_estimates, pred_detections, pred_identifications = test_scan(
            scan_path=scan_path,
            detection_model=detection_model,
            detection_X_shape=np.array([64, 64, 80]),
            detection_y_shape=np.array([32, 32, 40]),
            identification_model=identification_model,
            spacing=spacing)

        volume = opening_files.read_nii(scan_path, spacing=spacing)

        volume_slice = volume[cut, :, :]
        # detections_slice = pred_detections[cut, :, :]
        identifications_slice = pred_identifications[cut, :, :]
        # identifications_slice = np.max(pred_identifications, axis=0)

        # masked_data = np.ma.masked_where(identifications_slice == 0, identifications_slice)
        # masked_data = np.ma.masked_where(detections_slice == 0, detections_slice)

        axes[col].imshow(volume_slice.T, cmap='gray', origin='lower')
        # axes[col].imshow(masked_data.T, vmin=1, vmax=27, cmap=cm.jet, alpha=0.4, origin='lower')

        for label, centroid_idx in zip(labels, centroid_indexes):
            u, v = centroid_idx[1:3]
            axes[col].annotate(label, (u, v), color="white", size=6)
            axes[col].scatter(u, v, color="white", s=8)

        axes[col].plot(centroid_indexes[:, 1],
                       centroid_indexes[:, 2],
                       color="white")

        for pred_label, pred_centroid_idx in zip(pred_labels,
                                                 pred_centroid_estimates):
            u, v = pred_centroid_idx[1:3]
            axes[col].annotate(pred_label, (u, v), color="red", size=6)
            axes[col].scatter(u, v, color="red", s=8)

        pred_centroid_estimates = np.array(pred_centroid_estimates)
        axes[col].plot(pred_centroid_estimates[:, 1],
                       pred_centroid_estimates[:, 2],
                       color="red")

        # get average distance
        total_difference = 0.0
        no = 0.0
        for pred_label, pred_centroid_idx in zip(pred_labels,
                                                 pred_centroid_estimates):
            if pred_label in labels:
                label_idx = labels.index(pred_label)
                print(pred_label, centroid_indexes[label_idx],
                      pred_centroid_idx)
                total_difference += np.linalg.norm(pred_centroid_idx -
                                                   centroid_indexes[label_idx])
                no += 1

        # guard against scans where no predicted label matched the ground truth
        average_difference = total_difference / no if no > 0 else float("nan")
        print("average", average_difference)
        # note: this distance is in voxel indices; multiply by the spacing to
        # express it in mm
        axes[col].set_xlabel("{:.2f}".format(average_difference) + " mm",
                             fontsize=10)

    fig.subplots_adjust(wspace=0.2, hspace=0.4)
    fig.savefig(plot_path + '/centroids_' + str(start) + '_' + str(end) +
                '.png')
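Since centroid_indexes are physical centroids divided by the spacing, the average difference computed above is in voxel units; a small conversion sketch, valid only under the isotropic-spacing assumption used in these examples:

def voxel_distance_to_mm(distance_in_voxels, spacing=(2.0, 2.0, 2.0)):
    # only meaningful when all spacing components are equal (isotropic)
    return distance_in_voxels * spacing[0]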
Example #4
def complete_detection_picture(scans_dir,
                              models_dir,
                              plot_path,
                              spacing=(2.0, 2.0, 2.0)):

    scan_paths = glob.glob(scans_dir + "/**/*.nii.gz", recursive=True)
    model_paths = glob.glob(models_dir + "/*.h5", recursive=True)
    no_of_scan_paths = len(scan_paths)
    no_of_model_paths = len(model_paths)
    print("rows", no_of_model_paths, "cols", no_of_scan_paths)

    weights = np.array([0.1, 0.9])
    model_objects = {
        'loss': weighted_categorical_crossentropy(weights),
        'binary_recall': km.binary_recall(),
        'dice_coef': dice_coef_label(label=1)
    }

    fig, axes = plt.subplots(nrows=no_of_model_paths,
                             ncols=no_of_scan_paths,
                             figsize=(20, 10),
                             dpi=300)

    i = 1

    for col, scan_path in enumerate(scan_paths):

        scan_path_without_ext = scan_path[:-len(".nii.gz")]
        centroid_path = scan_path_without_ext + ".lml"

        _, centroids = opening_files.extract_centroid_info_from_lml(
            centroid_path)

        scan_name = (scan_path.rsplit('/', 1)[-1])[:-len(".nii.gz")]
        axes[0, col].set_title(scan_name, fontsize=10, pad=10)

        for row, model_path in enumerate(model_paths):
            print(i)

            size = np.array([30, 30, 36])
            current_spacing = spacing
            if model_path == "saved_current_models/detec-15:59.h5" \
                    or model_path == "saved_current_models/detec-15:59-20e.h5":
                print("here")
                size = np.array([64, 64, 80])
                current_spacing = (1.0, 1.0, 1.0)

            centroid_indexes = centroids / np.array(current_spacing)
            cut = np.round(np.mean(centroid_indexes[:, 0])).astype(int)

            model_name = (model_path.rsplit('/', 1)[-1])[:-len(".h5")]
            axes[row, 0].set_ylabel(model_name,
                                    rotation=0,
                                    labelpad=50,
                                    fontsize=10)

            volume = opening_files.read_nii(scan_path, spacing=current_spacing)
            detection_model = load_model(model_path,
                                         custom_objects=model_objects)

            # apply_detection_model takes four arguments elsewhere; assuming
            # the output patch is half the input patch, as in the other
            # examples ([64, 64, 80] -> [32, 32, 40])
            detections = apply_detection_model(volume, detection_model, size,
                                               size // 2)

            volume_slice = volume[cut, :, :]
            detections_slice = detections[cut, :, :]

            masked_data = np.ma.masked_where(detections_slice == 0,
                                             detections_slice)

            axes[row, col].imshow(volume_slice.T, cmap='gray')
            axes[row, col].imshow(masked_data.T, cmap=cm.autumn, alpha=0.4)

            i += 1

    fig.subplots_adjust(wspace=-0.2, hspace=0.4)
    fig.savefig(plot_path + '/detection-complete.png')
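A hedged invocation sketch for the grid plot above; the directory names are placeholders:

complete_detection_picture(
    scans_dir="datasets/spine-test",    # placeholder scan directory
    models_dir="saved_current_models",  # directory of .h5 detection models
    plot_path="plots",
    spacing=(2.0, 2.0, 2.0))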
Example #5
def generate_samples(dataset_dir,
                     sample_dir,
                     spacing,
                     sample_size,
                     no_of_samples,
                     no_of_zero_samples,
                     file_ext=".nii.gz"):

    # convert to a numpy array so it can be divided elementwise later on
    sample_size = np.array(sample_size)

    ext_len = len(file_ext)

    paths = glob.glob(dataset_dir + "/**/*" + file_ext, recursive=True)

    np.random.seed(1)

    sample_size_np = np.array(sample_size, int)
    print("Generating {} detection samples of size {} x {} x {} for {} scans"
          .format(no_of_samples * len(paths), sample_size_np[0],
                  sample_size_np[1], sample_size_np[2], len(paths)))

    for cnt, data_path in enumerate(paths):
        # get path to corresponding metadata
        data_path_without_ext = data_path[:-ext_len]
        metadata_path = data_path_without_ext + ".lml"

        # get image, resample it and scale centroids accordingly
        labels, centroids = opening_files.extract_centroid_info_from_lml(
            metadata_path)
        centroid_indexes = np.round(centroids / np.array(spacing)).astype(int)

        volume = opening_files.read_nii(data_path, spacing=spacing)

        # densely populate
        disk_indices = pre_compute_disks(spacing)
        dense_labelling = densely_label(volume.shape,
                                        disk_indices,
                                        labels,
                                        centroid_indexes,
                                        use_labels=False)
        # dense_labelling = spherical_densely_label(volume.shape, 14.0, labels, centroid_indexes, use_labels=False)

        sample_size_in_pixels = (sample_size / np.array(spacing)).astype(int)

        # pad any axis that is smaller than the required sample size
        for axis in range(3):
            if volume.shape[axis] < sample_size_in_pixels[axis]:
                dif = sample_size_in_pixels[axis] - volume.shape[axis]
                padding = [(0, 0)] * 3
                padding[axis] = (0, dif)
                volume = np.pad(volume, padding,
                                mode="constant",
                                constant_values=-5)
                dense_labelling = np.pad(dense_labelling, padding,
                                         mode="constant")

        random_area = volume.shape - sample_size_in_pixels

        name = (data_path.rsplit('/', 1)[-1])[:-ext_len]
    i = 0  # accepted samples
    j = 0  # accepted all-background ("zero") samples
    while i < no_of_samples:

            random_factor = np.random.rand(3)
            random_position = np.round(random_area * random_factor).astype(int)
            corner_a = random_position
            corner_b = random_position + sample_size_in_pixels

            sample = volume[corner_a[0]:corner_b[0], corner_a[1]:corner_b[1],
                            corner_a[2]:corner_b[2]]
            labelling = dense_labelling[corner_a[0]:corner_b[0],
                                        corner_a[1]:corner_b[1],
                                        corner_a[2]:corner_b[2]]

            # accept the crop if it contains labelled voxels, or if the
            # all-background quota is not yet filled
            unique_labels = np.unique(labelling).shape[0]
            if unique_labels > 1 or j < no_of_zero_samples:
                if unique_labels == 1:
                    j += 1
                i += 1

                # save file
                name_plus_id = name + "-" + str(i)
                path = '/'.join([sample_dir, name_plus_id])
                sample_path = path + "-sample"
                labelling_path = path + "-labelling"
                np.save(sample_path, sample)
                np.save(labelling_path, labelling)

        print(str(cnt + 1) + " / " + str(len(paths)))
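A hedged invocation sketch; note that sample_size is given in millimetres and divided by the spacing inside the function, so the values below are purely illustrative:

generate_samples(
    dataset_dir="datasets/spine-train",  # placeholder
    sample_dir="samples/detection",      # placeholder output directory
    spacing=(2.0, 2.0, 2.0),
    sample_size=(64, 64, 80),            # mm -> (32, 32, 40) voxels at 2 mm
    no_of_samples=20,
    no_of_zero_samples=4)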
Example #6
def generate_slice_samples(dataset_dir,
                           sample_dir,
                           sample_size=(40, 160),
                           spacing=(2.0, 2.0, 2.0),
                           no_of_samples=5,
                           no_of_vertebrae_in_each=2,
                           file_ext=".nii.gz"):
    sample_size = np.array(sample_size)
    ext_len = len(file_ext)

    paths = glob.glob(dataset_dir + "/**/*" + file_ext, recursive=True)

    np.random.seed(1)

    sample_size_np = np.array(sample_size, int)
    print("Generating {} identification samples of size 8 x {} x {} for {} scans"
          .format(no_of_samples * len(paths), sample_size_np[0],
                  sample_size_np[1], len(paths)))

    for cnt, data_path in enumerate(paths):

        # get path to corresponding metadata
        data_path_without_ext = data_path[:-ext_len]
        metadata_path = data_path_without_ext + ".lml"

        volume = opening_files.read_nii(data_path, spacing=spacing)

        # print(volume.shape)
        labels, centroids = opening_files.extract_centroid_info_from_lml(
            metadata_path)
        centroid_indexes = np.round(centroids / np.array(spacing)).astype(int)

        disk_indices = pre_compute_disks(spacing)
        dense_labelling = densely_label(volume.shape,
                                        disk_indices,
                                        labels,
                                        centroid_indexes,
                                        use_labels=True)
        # dense_labelling = spherical_densely_label(volume.shape, 14.0, labels, centroid_indexes, use_labels=True)

        # dense_labelling_squashed = np.any(dense_labelling, axis=(1, 2))
        # lower_i = np.min(np.where(dense_labelling_squashed == 1))
        # upper_i = np.max(np.where(dense_labelling_squashed == 1))
        lower_i = np.min(centroid_indexes[:, 0])
        lower_i = np.max([lower_i - 15, 0]).astype(int)
        upper_i = np.max(centroid_indexes[:, 0])
        upper_i = np.min([upper_i + 15, volume.shape[0] - 1]).astype(int)

        # pick random sagittal cuts whose 8-slice slab contains more than
        # no_of_vertebrae_in_each distinct labels (the count includes the
        # background label)
        cuts = []
        while len(cuts) < no_of_samples:
            # cut = np.random.randint(lower_i + 4, high=upper_i - 4)
            cut = np.random.randint(lower_i, high=upper_i)
            sample_labels_slice = dense_labelling[cut - 4:cut + 4, :, :]
            # sample_labels_slice = dense_labelling[cut, :, :]
            if np.unique(
                    sample_labels_slice).shape[0] > no_of_vertebrae_in_each:
                cuts.append(cut)

        name = (data_path.rsplit('/', 1)[-1])[:-ext_len]

        count = 0
        for i in cuts:

            volume_slice = volume[i - 4:i + 4, :, :]
            # volume_slice = volume[i, :, :]
            sample_labels_slice = dense_labelling[i, :, :]

            # a slab truncated at the volume border ends sampling for this scan
            if volume_slice.shape[0] != 8:
                break

            # get vertebrae identification map
            # detection_slice = (sample_labels_slice > 0).astype(int)

            [volume_slice,
             sample_labels_slice] = elasticdeform.deform_random_grid(
                 [volume_slice,
                  np.expand_dims(sample_labels_slice, axis=0)],
                 sigma=7,
                 points=3,
                 order=0,
                 axis=(1, 2))

            sample_labels_slice = np.squeeze(sample_labels_slice, axis=0)

            # pad the in-plane axes if they are smaller than the sample size
            if volume_slice.shape[1] < sample_size[0]:
                dif = sample_size[0] - volume_slice.shape[1]
                volume_slice = np.pad(volume_slice, ((0, 0), (0, dif), (0, 0)),
                                      mode="constant",
                                      constant_values=-5)
                # detection_slice = np.pad(detection_slice, ((0, dif), (0, 0)),
                #                         mode="constant")
                sample_labels_slice = np.pad(sample_labels_slice,
                                             ((0, dif), (0, 0)),
                                             mode="constant")

            if volume_slice.shape[2] < sample_size[1]:
                dif = sample_size[1] - volume_slice.shape[2]
                volume_slice = np.pad(volume_slice, ((0, 0), (0, 0), (0, dif)),
                                      mode="constant",
                                      constant_values=-5)
                # detection_slice = np.pad(detection_slice, ((0, 0), (0, dif)),
                #                         mode="constant")
                sample_labels_slice = np.pad(sample_labels_slice,
                                             ((0, 0), (0, dif)),
                                             mode="constant")

            # volume_slice = np.expand_dims(volume_slice, axis=2)
            # detection_slice = np.expand_dims(detection_slice, axis=2)
            # combines_slice = np.concatenate((volume_slice, detection_slice), axis=2)
            # draw random in-plane crops until one contains enough labelled
            # voxels (> 500), giving up after 100 attempts
            j = 0
            while True:
                random_area = volume_slice.shape[1:3] - sample_size
                # random_area = volume_slice.shape - sample_size
                random_factor = np.random.rand(2)
                random_position = np.round(random_area *
                                           random_factor).astype(int)
                corner_a = random_position
                corner_b = corner_a + sample_size

                cropped_combines_slice = volume_slice[:,
                                                      corner_a[0]:corner_b[0],
                                                      corner_a[1]:corner_b[1]]
                # cropped_combines_slice = volume_slice[corner_a[0]:corner_b[0], corner_a[1]:corner_b[1]]
                cropped_sample_labels_slice = sample_labels_slice[
                    corner_a[0]:corner_b[0], corner_a[1]:corner_b[1]]

                care_about_labels = np.count_nonzero(
                    cropped_sample_labels_slice)
                j += 1
                if care_about_labels > 500 or j > 100:
                    break

            # save file
            count += 1
            name_plus_id = name + "-" + str(count)
            path = '/'.join([sample_dir, name_plus_id])
            sample_path = path + "-sample"
            labelling_path = path + "-labelling"
            np.save(sample_path, cropped_combines_slice)
            np.save(labelling_path, cropped_sample_labels_slice)
        print(str(cnt + 1) + " / " + str(len(paths)))
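A matching invocation sketch for the slice sampler; the directories are placeholders and the keyword values restate the defaults:

generate_slice_samples(
    dataset_dir="datasets/spine-train",    # placeholder
    sample_dir="samples/identification",   # placeholder output directory
    sample_size=(40, 160),
    spacing=(2.0, 2.0, 2.0),
    no_of_samples=5,
    no_of_vertebrae_in_each=2)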