예제 #1
0
def load_inputs(image_path,
                metadata_path,
                use_heatmaps,
                benign_heatmap_path=None,
                malignant_heatmap_path=None):
    """
    Build a ModelInput for a single exam image, plus its heatmaps on request.

    Both heatmap paths must be supplied when ``use_heatmaps`` is truthy and
    must be left as None otherwise (checked with ``assert``). The view and the
    horizontal-flip flag are read from the unpickled metadata and applied to
    the image and, when requested, to the heatmaps.
    """
    # Validate the paths in benign-then-malignant order.
    for heatmap_path in (benign_heatmap_path, malignant_heatmap_path):
        if use_heatmaps:
            assert heatmap_path is not None
        else:
            assert heatmap_path is None

    metadata = pickling.unpickle_from_file(metadata_path)
    full_view = metadata["full_view"]
    flip = metadata["horizontal_flip"]

    image = loading.load_image(
        image_path=image_path,
        view=full_view,
        horizontal_flip=flip,
    )

    heatmaps = None
    if use_heatmaps:
        heatmaps = loading.load_heatmaps(
            benign_heatmap_path=benign_heatmap_path,
            malignant_heatmap_path=malignant_heatmap_path,
            view=full_view,
            horizontal_flip=flip,
        )

    return ModelInput(image=image, heatmaps=heatmaps, metadata=metadata)
예제 #2
0
def convert_output_results(input_data_folder, heatmaps_path,
                           data_path, output_data_folder):
    """
    Blend each exam's heatmaps onto the matching original DICOM image and
    save the result as a DICOM Secondary Capture (SC) file.

    :param input_data_folder: folder searched recursively for ``*.dcm`` files;
        the file used for each view is picked with ``find_view``.
    :param heatmaps_path: folder containing ``heatmap_benign`` and
        ``heatmap_malignant`` subfolders of ``<short_file_path>.png`` heatmaps.
    :param data_path: path of the pickled exam list.
    :param output_data_folder: folder the ``SC_<view>.dcm`` files are written
        to; it is created when missing.
    """
    exam_list = pickling.unpickle_from_file(data_path)
    # Create the output folder itself (not just its parent directory), since
    # every SC file is written directly inside it.
    if output_data_folder:
        os.makedirs(output_data_folder, exist_ok=True)
    dcm_files = glob.glob(os.path.join(input_data_folder, "**", "*.dcm"),
                          recursive=True)
    image_extension = ".png"
    for datum in tqdm.tqdm(exam_list):
        for view in VIEWS.LIST:
            laterality = view.split("-")[0]
            projection_view = view.split("-")[1]
            for short_file_path in datum[view]:
                loaded_heatmaps = loading.load_heatmaps(
                    benign_heatmap_path=os.path.join(
                        heatmaps_path, "heatmap_benign",
                        short_file_path + image_extension),
                    malignant_heatmap_path=os.path.join(
                        heatmaps_path, "heatmap_malignant",
                        short_file_path + image_extension),
                    view=view,
                    horizontal_flip=datum["horizontal_flip"],
                )
                # Build a uint8 RGB overlay from the two heatmap channels:
                # R = channel 1, G = channel 0, B = zeros.
                heatmap_rgb = np.stack(
                    [loaded_heatmaps[:, :, 1:2],
                     loaded_heatmaps[:, :, 0:1],
                     np.zeros(loaded_heatmaps[:, :, 1:2].shape)],
                    axis=2)[:, :, :, 0].astype(np.uint8)
                if laterality == "R":
                    heatmap_rgb = np.fliplr(heatmap_rgb)

                dcm_file = find_view(dcm_files, laterality, projection_view)
                # dcmread replaces the deprecated read_file alias
                # (removed in pydicom 3.0).
                ds = pydicom.dcmread(dcm_file)

                # Rescale the original pixels to 0..255. Work in float64 so
                # that subtracting the minimum cannot overflow signed integer
                # pixel data, and guard constant images against 0/0 (NaN).
                pixels = ds.pixel_array.astype(np.float64)
                pixel_min = np.min(pixels)
                pixel_range = np.max(pixels) - pixel_min
                if pixel_range > 0:
                    pixels = (pixels - pixel_min) / pixel_range * 255
                else:
                    pixels = np.zeros(pixels.shape)
                gray = pixels.astype(np.uint8)
                pixel_array = np.stack([gray, gray, gray], axis=2)

                # Blend the overlay (25 % weight) into the heatmap window only.
                coords = datum["window_location"][view][0]
                rows = slice(coords[0], coords[1])
                cols = slice(coords[2], coords[3])
                background = Image.fromarray(pixel_array[rows, cols, 0:3])
                foreground = Image.fromarray(heatmap_rgb)
                blended = Image.blend(background, foreground, 0.25)
                pixel_array[rows, cols, 0:3] = np.asarray(blended)

                sc_image = SCImage()
                sc_image.create_empty_iod()
                sc_image.initiate()
                sc_image.set_dicom_attribute("PatientName", ds.PatientName)
                sc_image.set_dicom_attribute("PatientID", ds.PatientID)
                sc_image.set_dicom_attribute("AccessionNumber", ds.AccessionNumber)
                sc_image.set_dicom_attribute("StudyID", ds.StudyID)
                sc_image.set_dicom_attribute("StudyInstanceUID", ds.StudyInstanceUID)
                sc_image.set_dicom_attribute(
                    "StudyDate", ds.StudyDate if "StudyDate" in ds else "")
                sc_image.set_dicom_attribute(
                    "StudyTime", ds.StudyTime if "StudyTime" in ds else "")
                sc_image.set_dicom_attribute(
                    "StudyDescription",
                    ds.StudyDescription if "StudyDescription" in ds else "")
                sc_image.set_dicom_attribute("SeriesDescription",
                                             f"Original {view} + heatmap")
                sc_image.add_pixel_data(pixel_array)
                # NOTE(review): the file name depends only on the view, so
                # several exams / images of the same view overwrite each other.
                output_file = os.path.join(output_data_folder, "SC_" + view + ".dcm")
                sc_image.write_to_file(output_file)
예제 #3
0
파일: run_model.py 프로젝트: BigOsoft/IMA
def run_model(model, device, exam_list, parameters):
    """
    Returns predictions of image only model or image+heatmaps model.
    Prediction for each exam is averaged for a given number of epochs.

    For every exam, all images of every view in ``VIEWS.LIST`` (plus their
    heatmaps when ``parameters["use_heatmaps"]`` is set) are loaded once. The
    model is then run ``parameters["num_epochs"]`` times, in batches of
    ``parameters["batch_size"]``. Each run uses one cropped/normalized image
    per view; with ``parameters["augmentation"]`` that image is picked at
    random and augmented. Per-label predictions are averaged over view angles
    and over all runs.

    :param model: callable taking a ``{view: tensor}`` dict; its output is
        passed to ``compute_batch_predictions``.
    :param device: torch device the input tensors are moved to.
    :param exam_list: list of exam dicts holding the per-view file lists plus
        ``"horizontal_flip"`` and ``"best_center"`` entries.
    :param parameters: dict of run settings (seed, paths, epoch/batch sizes,
        augmentation and crop-noise settings, ``model_mode``).
    :return: array with one row per exam and one column per label, in the
        order of ``LABELS.LIST``.
    """
    random_number_generator = np.random.RandomState(parameters["seed"])

    image_extension = ".hdf5" if parameters["use_hdf5"] else ".png"

    with torch.no_grad():
        predictions_ls = []
        for datum in tqdm.tqdm(exam_list):
            predictions_for_datum = []
            loaded_image_dict = {view: [] for view in VIEWS.LIST}
            loaded_heatmaps_dict = {view: [] for view in VIEWS.LIST}
            for view in VIEWS.LIST:
                for short_file_path in datum[view]:
                    loaded_image = loading.load_image(
                        image_path=os.path.join(
                            parameters["image_path"],
                            short_file_path + image_extension),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                    if parameters["use_heatmaps"]:
                        # Heatmaps are always read from .hdf5 files,
                        # regardless of ``use_hdf5``.
                        loaded_heatmaps = loading.load_heatmaps(
                            benign_heatmap_path=os.path.join(
                                parameters["heatmaps_path"], "heatmap_benign",
                                short_file_path + ".hdf5"),
                            malignant_heatmap_path=os.path.join(
                                parameters["heatmaps_path"],
                                "heatmap_malignant",
                                short_file_path + ".hdf5"),
                            view=view,
                            horizontal_flip=datum["horizontal_flip"],
                        )
                    else:
                        loaded_heatmaps = None

                    loaded_image_dict[view].append(loaded_image)
                    loaded_heatmaps_dict[view].append(loaded_heatmaps)
            for data_batch in tools.partition_batch(
                    range(parameters["num_epochs"]), parameters["batch_size"]):
                batch_dict = {view: [] for view in VIEWS.LIST}
                for _ in data_batch:
                    for view in VIEWS.LIST:
                        # Without augmentation the first image of the view is used.
                        image_index = 0
                        if parameters["augmentation"]:
                            image_index = random_number_generator.randint(
                                low=0, high=len(datum[view]))
                        cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                            image=loaded_image_dict[view][image_index],
                            auxiliary_image=loaded_heatmaps_dict[view][image_index],
                            view=view,
                            best_center=datum["best_center"][view][image_index],
                            random_number_generator=random_number_generator,
                            augmentation=parameters["augmentation"],
                            max_crop_noise=parameters["max_crop_noise"],
                            max_crop_size_noise=parameters["max_crop_size_noise"],
                        )
                        # Channel 0 is the image; heatmap channels follow.
                        if loaded_heatmaps_dict[view][image_index] is None:
                            batch_dict[view].append(cropped_image[:, :, np.newaxis])
                        else:
                            batch_dict[view].append(np.concatenate(
                                [cropped_image[:, :, np.newaxis], cropped_heatmaps],
                                axis=2))

                # Stack the samples and permute (N, H, W, C) -> (N, C, H, W).
                tensor_batch = {
                    view: torch.tensor(np.stack(batch_dict[view])).permute(
                        0, 3, 1, 2).to(device)
                    for view in VIEWS.LIST
                }
                output = model(tensor_batch)
                batch_predictions = compute_batch_predictions(
                    output, mode=parameters["model_mode"])
                # Keep column 1 of every prediction array; the dict keys form
                # two-level (label, view_angle) columns.
                pred_df = pd.DataFrame(
                    {k: v[:, 1]
                     for k, v in batch_predictions.items()})
                pred_df.columns.names = ["label", "view_angle"]
                # Average over view angles per label. Grouping on the "label"
                # index level keeps the string "view_angle" values out of
                # mean(), which would raise TypeError on pandas >= 2.0.
                predictions = pred_df.T.groupby(
                    level="label").mean().T[LABELS.LIST].values
                predictions_for_datum.append(predictions)
            predictions_ls.append(
                np.mean(np.concatenate(predictions_for_datum, axis=0), axis=0))

    return np.array(predictions_ls)
    def __getitem__(self, idx):
        """
        Return batch ``idx`` as ``(batch_x, batch_y)``.

        ``batch_x`` maps each view in ``VIEWS.LIST`` to a single array that
        holds one cropped/normalized sample per exam of the batch. The channel
        axis is moved to position 1. Channel 0 is the image; the heatmap
        channels follow when ``parameters["use_heatmaps"]`` is set.
        ``batch_y`` stacks ``self.y[patient]`` for every exam of the batch.

        The loading/augmentation part is adapted from run_model.py.
        """
        batch_x_dict = self.x[idx * self.batch_size:(idx + 1) *
                              self.batch_size]
        image_extension = ".hdf5" if self.parameters["use_hdf5"] else ".png"

        batch_x_feed_new = {view: [] for view in VIEWS.LIST}
        batch_y = []
        for datum in batch_x_dict:
            # Patient id = second "_"-separated field of the first L-CC file
            # name (relies on the data set's file naming scheme).
            patient = datum['L-CC'][0].split('_')[1]
            for view in VIEWS.LIST:
                # Choose the image index first and reset it for every view.
                # The loaded file and its best_center entry must refer to the
                # same image, and the index must be valid for this view.
                image_index = 0
                if self.parameters["augmentation"]:
                    image_index = self.random_number_generator.randint(
                        low=0, high=len(datum[view]))
                short_file_path = datum[view][image_index]

                loaded_image = loading.load_image(
                    image_path=os.path.join(self.parameters["image_path"],
                                            short_file_path + image_extension),
                    view=view,
                    horizontal_flip=datum["horizontal_flip"],
                )
                if self.parameters["use_heatmaps"]:
                    loaded_heatmaps = loading.load_heatmaps(
                        benign_heatmap_path=os.path.join(
                            self.parameters["heatmaps_path"], "heatmap_benign",
                            short_file_path + ".hdf5"),
                        malignant_heatmap_path=os.path.join(
                            self.parameters["heatmaps_path"],
                            "heatmap_malignant", short_file_path + ".hdf5"),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                else:
                    loaded_heatmaps = None

                cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                    image=loaded_image,
                    auxiliary_image=loaded_heatmaps,
                    view=view,
                    best_center=datum["best_center"][view][image_index],
                    random_number_generator=self.random_number_generator,
                    augmentation=self.parameters["augmentation"],
                    max_crop_noise=self.parameters["max_crop_noise"],
                    max_crop_size_noise=self.parameters["max_crop_size_noise"],
                )
                if loaded_heatmaps is None:
                    sample = cropped_image[:, :, np.newaxis]
                else:
                    sample = np.concatenate(
                        [cropped_image[:, :, np.newaxis], cropped_heatmaps],
                        axis=2)
                batch_x_feed_new[view].append(sample)

            batch_y.append(self.y[patient])

        batch_y = np.stack(np.array(batch_y), axis=0)

        for view in VIEWS.LIST:
            # Stack over the batch, then move the channel axis (last) to
            # axis 1: (N, H, W, C) -> (N, C, H, W).
            batch_x_feed_new[view] = np.moveaxis(
                np.stack(np.array(batch_x_feed_new[view]), axis=0), -1, 1)

        return batch_x_feed_new, batch_y
예제 #5
0
def run_model(model, device, exam_list, parameters):
    """
    Returns predictions of image only model or image+heatmaps model.
    Prediction for each exam is averaged for a given number of epochs.

    For every exam, all images of every view in ``VIEWS.LIST`` (plus their
    heatmaps when ``parameters["use_heatmaps"]`` is set) are loaded once. The
    model is then run ``parameters["num_epochs"]`` times, in batches of
    ``parameters["batch_size"]``. Each run uses one cropped/normalized image
    per view; with ``parameters["augmentation"]`` each view independently
    picks a random image, which is then augmented. Per-label predictions are
    averaged over view angles and over all runs.

    :param model: callable taking a ``{view: tensor}`` dict; its output is
        passed to ``compute_batch_predictions``.
    :param device: torch device the input tensors are moved to.
    :param exam_list: list of exam dicts holding the per-view file lists plus
        ``"horizontal_flip"`` and ``"best_center"`` entries.
    :param parameters: dict of run settings (seed, paths, epoch/batch sizes,
        augmentation and crop-noise settings, ``model_mode``).
    :return: array with one row per exam and one column per label, in the
        order of ``LABELS.LIST``.
    """
    random_number_generator = np.random.RandomState(parameters["seed"])

    image_extension = ".hdf5" if parameters["use_hdf5"] else ".png"

    with torch.no_grad():
        predictions_ls = []
        for datum in tqdm.tqdm(exam_list):
            predictions_for_datum = []
            # VIEWS.LIST holds the view names used as keys throughout.
            loaded_image_dict = {view: [] for view in VIEWS.LIST}
            loaded_heatmaps_dict = {view: [] for view in VIEWS.LIST}
            for view in VIEWS.LIST:
                # Load every image of this view for the current exam.
                for short_file_path in datum[view]:
                    loaded_image = loading.load_image(
                        image_path=os.path.join(
                            parameters["image_path"],
                            short_file_path + image_extension),
                        view=view,
                        horizontal_flip=datum["horizontal_flip"],
                    )
                    if parameters["use_heatmaps"]:
                        # Heatmaps are always read from .hdf5 files,
                        # regardless of ``use_hdf5``.
                        loaded_heatmaps = loading.load_heatmaps(
                            benign_heatmap_path=os.path.join(
                                parameters["heatmaps_path"], "heatmap_benign",
                                short_file_path + ".hdf5"),
                            malignant_heatmap_path=os.path.join(
                                parameters["heatmaps_path"],
                                "heatmap_malignant",
                                short_file_path + ".hdf5"),
                            view=view,
                            horizontal_flip=datum["horizontal_flip"],
                        )
                    else:
                        loaded_heatmaps = None

                    loaded_image_dict[view].append(loaded_image)
                    loaded_heatmaps_dict[view].append(loaded_heatmaps)
            for data_batch in tools.partition_batch(
                    range(parameters["num_epochs"]), parameters["batch_size"]):
                batch_dict = {view: [] for view in VIEWS.LIST}
                for _ in data_batch:
                    for view in VIEWS.LIST:
                        # Without augmentation the first image of the view is
                        # used; with it, each view draws its own random index.
                        image_index = 0
                        if parameters["augmentation"]:
                            image_index = random_number_generator.randint(
                                low=0, high=len(datum[view]))
                        cropped_image, cropped_heatmaps = loading.augment_and_normalize_image(
                            image=loaded_image_dict[view][image_index],
                            auxiliary_image=loaded_heatmaps_dict[view][image_index],
                            view=view,
                            best_center=datum["best_center"][view][image_index],
                            random_number_generator=random_number_generator,
                            augmentation=parameters["augmentation"],
                            max_crop_noise=parameters["max_crop_noise"],
                            max_crop_size_noise=parameters["max_crop_size_noise"],
                        )
                        # Channel 0 is the image; heatmap channels follow.
                        if loaded_heatmaps_dict[view][image_index] is None:
                            batch_dict[view].append(cropped_image[:, :, np.newaxis])
                        else:
                            batch_dict[view].append(np.concatenate(
                                [cropped_image[:, :, np.newaxis], cropped_heatmaps],
                                axis=2))

                # np.stack adds the batch axis; permute
                # (N, H, W, C) -> (N, C, H, W) for the model.
                tensor_batch = {
                    view: torch.tensor(np.stack(batch_dict[view])).permute(
                        0, 3, 1, 2).to(device)
                    for view in VIEWS.LIST
                }
                output = model(tensor_batch)
                batch_predictions = compute_batch_predictions(
                    output, mode=parameters["model_mode"])
                # Keep column 1 of every prediction array (presumably the
                # positive-class probability -- confirm against
                # compute_batch_predictions). The dict keys form two-level
                # (label, view_angle) columns.
                pred_df = pd.DataFrame(
                    {k: v[:, 1]
                     for k, v in batch_predictions.items()})
                pred_df.columns.names = ["label", "view_angle"]
                # Average over view angles per label. Grouping on the "label"
                # index level keeps the string "view_angle" values out of
                # mean(), which would raise TypeError on pandas >= 2.0.
                predictions = pred_df.T.groupby(
                    level="label").mean().T[LABELS.LIST].values
                predictions_for_datum.append(predictions)
            predictions_ls.append(
                np.mean(np.concatenate(predictions_for_datum, axis=0), axis=0))

    return np.array(predictions_ls)