def save_heatmaps(heatmap_malignant, heatmap_benign, short_file_path, view, horizontal_flip, parameters):
    """
    Saves the heatmaps after flipping back to the original direction.

    Each heatmap is passed through ``loading.flip_image(heatmap, view,
    horizontal_flip)`` and written to
    ``os.path.join(parameters['save_heatmap_path'][i], short_file_path + ext)``,
    with ``i == 0`` for the malignant heatmap and ``i == 1`` for the benign one.

    The output format is chosen once from ``parameters['use_hdf5']``:
    truthy -> ``.hdf5`` written by ``saving_images.save_image_as_hdf5``;
    falsy  -> ``.png``  written by ``saving_images.save_image_as_png``.

    :param heatmap_malignant: heatmap written under ``save_heatmap_path[0]``.
    :param heatmap_benign: heatmap written under ``save_heatmap_path[1]``.
    :param short_file_path: path relative to each output directory, without
        extension (the extension is appended here).
    :param view: forwarded to ``loading.flip_image``.
    :param horizontal_flip: forwarded to ``loading.flip_image``.
    :param parameters: dict providing ``'use_hdf5'`` and ``'save_heatmap_path'``
        (an indexable of at least two output directories).
    """
    # Choose extension and writer together so they cannot disagree. This
    # replaces re-parsing the extension off the path, whose RuntimeError
    # fallback could never be reached.
    if parameters['use_hdf5']:
        image_extension, save_image = '.hdf5', saving_images.save_image_as_hdf5
    else:
        image_extension, save_image = '.png', saving_images.save_image_as_png

    # Resolve both destination directories up front so a malformed
    # 'save_heatmap_path' fails before any file has been written.
    save_dirs = parameters['save_heatmap_path']
    malignant_dir, benign_dir = save_dirs[0], save_dirs[1]

    heatmap_malignant = loading.flip_image(heatmap_malignant, view, horizontal_flip)
    heatmap_benign = loading.flip_image(heatmap_benign, view, horizontal_flip)

    for heatmap, save_dir in ((heatmap_malignant, malignant_dir), (heatmap_benign, benign_dir)):
        save_image(heatmap, os.path.join(save_dir, short_file_path + image_extension))
def produce_heatmaps(parameters):
    """
    Generate the two heatmaps for a single example and save them as HDF5.

    Pipeline, in execution order:
      1. seed Python's ``random`` module with ``parameters['seed']``;
      2. obtain ``(model, device)`` from ``run_producer.load_model``;
      3. unpickle metadata from ``parameters['metadata_path']``; its
         ``"view"`` and ``"horizontal_flip"`` entries are used for patch
         sampling and for flipping the heatmaps;
      4. sample patches from ``parameters["cropped_mammogram_path"]``;
      5. score all patches (``run_producer.get_all_prob``) and build one
         heatmap per entry of ``parameters['heatmap_type'][0]`` / ``[1]``;
      6. flip both heatmaps, then write them to
         ``parameters["heatmap_path_malignant"]`` and
         ``parameters["heatmap_path_benign"]``.

    :param parameters: configuration dict. Besides the keys above it must
        provide ``"minibatch_size"`` and ``'patch_size'``; the run_producer
        helpers also receive the whole dict.
    """
    random.seed(parameters['seed'])
    mammogram_path = parameters["cropped_mammogram_path"]
    model, device = run_producer.load_model(parameters)
    meta = pickling.unpickle_from_file(parameters['metadata_path'])
    view = meta["view"]
    flip = meta['horizontal_flip']

    patches, case = run_producer.sample_patches_single(
        image_path=mammogram_path,
        view=view,
        horizontal_flip=flip,
        parameters=parameters,
    )
    all_prob = run_producer.get_all_prob(
        all_patches=patches,
        minibatch_size=parameters["minibatch_size"],
        model=model,
        device=device,
        parameters=parameters,
    )

    # Arguments common to both heatmap types. Both calls start at
    # patch_counter=0, and the counter each call returns is not needed.
    shared_kwargs = {
        "patch_counter": 0,
        "all_prob": all_prob,
        "image_shape": case[0],
        "length_stride_list": case[4],
        "width_stride_list": case[3],
        "patch_size": parameters['patch_size'],
    }
    malignant, _ = run_producer.probabilities_to_heatmap(
        heatmap_type=parameters['heatmap_type'][0], **shared_kwargs)
    benign, _ = run_producer.probabilities_to_heatmap(
        heatmap_type=parameters['heatmap_type'][1], **shared_kwargs)

    # Flip both first, then save both (same ordering as before).
    flipped = [
        loading.flip_image(image=heatmap, view=view, horizontal_flip=flip)
        for heatmap in (malignant, benign)
    ]
    for heatmap, path_key in zip(flipped, ("heatmap_path_malignant", "heatmap_path_benign")):
        saving_images.save_image_as_hdf5(image=heatmap, filename=parameters[path_key])
def save_heatmaps(heatmap_malignant, heatmap_benign, short_file_path, view, horizontal_flip, parameters):
    """
    Flip both heatmaps back to the original direction and write them as HDF5.

    The malignant heatmap goes to
    ``os.path.join(parameters['save_heatmap_path'][0], short_file_path + '.hdf5')``
    and the benign one to the same name under ``save_heatmap_path[1]``.
    ``view`` and ``horizontal_flip`` are forwarded unchanged to
    ``loading.flip_image``.
    """
    flipped = [
        loading.flip_image(heatmap, view, horizontal_flip)
        for heatmap in (heatmap_malignant, heatmap_benign)
    ]
    # Index 0 -> malignant directory, index 1 -> benign directory.
    for slot, heatmap in enumerate(flipped):
        destination = os.path.join(parameters['save_heatmap_path'][slot], short_file_path + '.hdf5')
        saving_images.save_image_as_hdf5(heatmap, destination)
def extract_center(datum, image):
    """
    Compute the optimal window center for ``image``.

    The image is first flipped with ``loading.flip_image`` using
    ``datum["full_view"]`` and ``datum['horizontal_flip']``. A top-left /
    bottom-right constraint is then chosen by ``datum["view"]``:

      * ``"CC"``  -> rightmost-pixel constraint from ``rightmost_points[1]``;
      * ``"MLO"`` -> bottom-rightmost constraint from ``rightmost_points[1]``
        and ``bottommost_points[0]``;
      * anything else -> ``RuntimeError(view)``.

    The search window size is ``INPUT_SIZE_DICT[datum["full_view"]]``, and
    ``com`` is the integer midpoint of the flipped image's shape.

    :returns: ``(best_center_y, best_center_x)`` as reported by
        ``calc_optimal_centers.get_image_optimal_window_info``.
    """
    oriented = loading.flip_image(image, datum["full_view"], datum['horizontal_flip'])

    view = datum["view"]
    if view == "CC":
        constraint = calc_optimal_centers.get_rightmost_pixel_constraint(
            rightmost_x=datum["rightmost_points"][1],
        )
    elif view == "MLO":
        constraint = calc_optimal_centers.get_bottomrightmost_pixel_constraint(
            rightmost_x=datum["rightmost_points"][1],
            bottommost_y=datum["bottommost_points"][0],
        )
    else:
        raise RuntimeError(view)

    window_info = calc_optimal_centers.get_image_optimal_window_info(
        oriented,
        com=np.array(oriented.shape) // 2,
        window_dim=np.array(INPUT_SIZE_DICT[datum["full_view"]]),
        tl_br_constraint=constraint,
    )
    center_y = window_info["best_center_y"]
    center_x = window_info["best_center_x"]
    return center_y, center_x