def create_models(images, feature_set: FeatureSet, base_data_path, extension='tif',
                  main_window_size=(30, 30), percentage_threshold=0.5, class_ratio=1.3,
                  bands=WORLDVIEW3, cached=True):
    """Build one combined dataset (X, y, groups) from several images.

    For every image, the feature matrix X is computed with ``get_x_matrix`` and
    the label vector y with ``get_y_vector``. The per-image arrays are then
    stacked along axis 0. ``groups`` records, for every row, the position in
    *images* of the image it came from (usable e.g. for grouped evaluation).

    :param images: image names (without extension) located in *base_data_path*.
    :param feature_set: features used to compute X.
    :param base_data_path: directory holding ``<name>.<extension>`` images and
        ``<name>_masked.tif`` masks.
    :param extension: file extension of the satellite images.
    :param main_window_size: window size used for both features and labels.
    :param percentage_threshold: forwarded to ``get_y_vector`` as the labelling
        threshold.
    :param class_ratio: currently unused (dataset balancing is disabled).
    :param bands: band layout passed to ``SatelliteImage.load_from_file``.
    :param cached: forwarded to ``get_x_matrix``.
    :return: tuple ``(X, y, None, groups)``. The third slot is always ``None``;
        per-image masks are not merged.
    :raises ValueError: if *images* is empty.
    """
    data = []
    for group_num, image_name in enumerate(images):
        image_file = "{base_path}/{image_name}.{extension}".format(
            base_path=base_data_path, image_name=image_name, extension=extension)
        mask_full_path = "{base_path}/{image_name}_masked.tif".format(
            base_path=base_data_path, image_name=image_name)

        sat_image = SatelliteImage.load_from_file(image_file, bands)
        X = get_x_matrix(sat_image, image_name=image_name, feature_set=feature_set,
                         window_size=main_window_size, cached=cached)
        # Labels are recomputed on every call (label caching explicitly off).
        y, _real_mask = get_y_vector(mask_full_path, main_window_size,
                                     percentage_threshold, cached=False)
        print("X shape {}, y shape {}".format(X.shape, y.shape))
        # Keep only what is combined below; the full-size mask is not needed.
        data.append((X, y, np.full(y.shape, group_num)))

    if not data:
        raise ValueError("create_models() requires at least one image")

    # Stack in one pass. This also covers the single-image case and avoids
    # quadratic re-copying from repeated np.append.
    all_X = np.concatenate([X for X, _, _ in data], axis=0)
    all_y = np.concatenate([y for _, y, _ in data], axis=0)
    groups = np.concatenate([g for _, _, g in data], axis=0)
    print("Combined X shape {}, y shape {}".format(all_X.shape, all_y.shape))
    return (all_X, all_y, None, groups)
def load_image(image_name):
    """Open the WorldView-3 image called *image_name* from ``data_path()``.

    NOTE(review): ``extension`` is neither a parameter nor a local here. It is
    resolved from module scope at call time, so confirm it is defined first.
    NOTE(review): another ``load_image()`` definition appears later in this
    file and rebinds the name once it executes.
    """
    wv3_bands = WORLDVIEW3
    root = data_path()
    path = "{base_path}/{image_name}.{extension}".format(
        image_name=image_name,
        base_path=root,
        extension=extension,
    )
    return SatelliteImage.load_from_file(path, wv3_bands)
def load_image(*, imagefile='/home/bweel/Documents/projects/dynaslum/data/satelite/056239125010_01/056239125010_01_P001_MUL/08NOV02054348-M2AS_R1C1-056239125010_01_P001.TIF',
               bands=None):
    """Load a satellite scene as a ``SatelliteImage``.

    Both arguments are keyword-only, so existing ``load_image()`` calls behave
    exactly as before. The defaults are the QuickBird scene this helper was
    originally hard-coded to.

    NOTE(review): an earlier ``load_image(image_name)`` in this file is rebound
    by this definition.

    :param imagefile: path to the image file.
    :param bands: band layout used throughout the analysis; ``None`` means
        ``QUICKBIRD``.
    :return: the loaded ``SatelliteImage``.
    """
    if bands is None:
        # Resolved at call time, like the original, so QUICKBIRD need not exist at def time.
        bands = QUICKBIRD
    return SatelliteImage.load_from_file(imagefile, bands)
def _y_cache_key(binary_file_path, smallest_window_size, percentage_threshold, prefix='y_train'):
    """Return a cache key specific to one mask file, window size and threshold."""
    import os

    stem = os.path.splitext(os.path.basename(binary_file_path))[0]
    window = "x".join(str(s) for s in smallest_window_size)
    return "{}_{}_{}_{}".format(prefix, stem, window, percentage_threshold)


def _window_label(raw, percentage_threshold):
    """Return the binary label (0 or 1) for one window of mask values.

    If the smallest value in the window is not 0, the label is 1. For
    non-negative masks, that means the window contains no zero pixels.
    Otherwise the label is 1 when the fraction of non-zero pixels exceeds
    *percentage_threshold*. An empty window is labelled 0.
    """
    unique, counts = np.unique(raw, return_counts=True)
    if unique.size == 0:
        return 0
    if unique[0] != 0:
        return 1
    zeros = counts[0]
    non_zeros = np.sum(counts[1:])
    return 1 if non_zeros / (zeros + non_zeros) > percentage_threshold else 0


def get_y_vector(binary_file_path: str, smallest_window_size: tuple,
                 percentage_threshold: float = 0.5, cached: bool = False) -> tuple:
    """Compute per-window binary labels from a mask image.

    :param binary_file_path: path of the mask file, opened with GDAL.
    :param smallest_window_size: window size handed to ``CellGenerator``.
    :param percentage_threshold: non-zero fraction above which a window is 1.
    :param cached: when True, read/write both results via ``load_cache``/``cache``
        under keys derived from the file name, window size and threshold.
    :return: ``(y_train, real_mask)``. ``y_train`` is the flattened window-label
        matrix. ``real_mask`` is a ``uint8`` array of the mask's (extended) shape
        in which every window's pixels hold that window's label. Both code
        paths, cached or not, return this same tuple.
    """
    y_key = _y_cache_key(binary_file_path, smallest_window_size, percentage_threshold)
    mask_key = _y_cache_key(binary_file_path, smallest_window_size, percentage_threshold,
                            prefix='real_mask')
    if cached:
        y_train = load_cache(y_key)
        real_mask = load_cache(mask_key)
        if y_train is not None and real_mask is not None:
            return y_train, real_mask

    dataset = gdal.Open(binary_file_path, gdal.GA_ReadOnly)
    array = dataset.ReadAsArray()
    # Per-pixel minimum over axis 0 (presumably the band axis -- TODO confirm for
    # single-band masks), then re-add a trailing axis of length 1.
    array = np.min(array, 0)
    array = array[:, :, np.newaxis]
    binary_sat_image = SatelliteImage(dataset, array, MASK_BANDS)
    generator = CellGenerator(image=binary_sat_image, size=smallest_window_size)

    # Mask covering the whole image: each window's pixels get that window's label.
    real_mask = np.zeros(array.shape, dtype=np.uint8)
    # One label per window, laid out in the generator's window grid.
    y_matrix = np.zeros(generator.shape())
    for window in generator:
        y = _window_label(window.raw, percentage_threshold)
        y_matrix[window.x, window.y] = y
        real_mask[window.x_range, window.y_range, 0] = y

    y_train = y_matrix.flatten()
    if cached:
        cache(y_train, y_key)
        cache(real_mask, mask_key)
    return y_train, real_mask
# Sweep lacunarity box sizes and window sizes for every image.
class_ratio = 1.3
feature_set = FeatureSet()
pantex = Pantex(pantex_window_sizes)
for image_name in images:
    # Paths and the image depend only on image_name, so load once per image
    # instead of once per (box size, window size) combination.
    image_file = "{base_path}/{image_name}.{extension}".format(
        base_path=base_path, image_name=image_name, extension=extension)
    bands = WORLDVIEW3
    mask_full_path = "{base_path}/{image_name}_masked.tif".format(
        base_path=base_path, image_name=image_name)
    sat_image = SatelliteImage.load_from_file(image_file, bands)
    sat_image_shape = sat_image.shape

    for lac_box_size in (10, 20, 30):
        for lac_window in ((50, 50), (300, 300), (500, 500)):
            # Lacunarity expects a tuple of windows; this run uses a single one.
            lac_window_size = (lac_window,)
            lacunarity = Lacunarity(windows=lac_window_size, box_sizes=(lac_box_size,))
            # NOTE(review): the same FeatureSet is reused for every configuration.
            # Whether add() under an existing name replaces or accumulates depends on
            # FeatureSet -- confirm each run uses only its own lacunarity feature.
            feature_set.add(lacunarity, "LACUNARITY")
            classifier = RF_classifier()
def plot_overlap(y_pred, groups, im_num, image_name, mask_full_path, main_window_size,
                 current_time, results_path, image_full_path=None):
    """Plot the predicted mask for one image and score it against its ground truth.

    Predictions for image *im_num* are selected with ``groups == im_num``. They
    are painted window by window (255 for a prediction of 1, else 0) into a
    mask of the same shape as the ground-truth mask at *mask_full_path*. Two
    PNGs are saved to *results_path*: the mask over the RGB image, and the mask
    alone.

    :param y_pred: predicted labels for all images, one entry per window.
    :param groups: image index for every entry of *y_pred*.
    :param im_num: index of the image to plot.
    :param image_name: name used in the output file names.
    :param mask_full_path: ground-truth mask file, opened with GDAL.
    :param main_window_size: window size handed to ``CellGenerator``.
    :param current_time: timestamp used in the output file names.
    :param results_path: directory the PNGs are written to.
    :param image_full_path: satellite image drawn underneath the mask. When
        omitted, the module-level ``image_file`` is used, as before. That global
        may belong to a different image, so pass this explicitly.
    :return: value of ``jaccard_index_binary_masks`` for truth vs. prediction.
    """
    if image_full_path is None:
        # Backward compatibility: this function used to read a global `image_file`.
        image_full_path = image_file

    dataset = gdal.Open(mask_full_path, gdal.GA_ReadOnly)
    array = dataset.ReadAsArray()
    array = np.min(array, 0)
    array = array[:, :, np.newaxis]
    truth_mask = np.where(array > 0, 1, 0)

    binary_sat_image = SatelliteImage(dataset, array, MASK_BANDS)
    generator = CellGenerator(image=binary_sat_image, size=main_window_size)
    result_mask = np.zeros(array.shape, dtype=np.uint8)

    y_pred_im = y_pred[groups == im_num]
    print("unique y_pred", np.unique(y_pred_im, return_counts=True))
    print(y_pred_im.shape)
    print(y_pred_im)
    print("Gen shape", generator.x_length, generator.y_length)

    # Assumes the generator visits windows in the same order used to build the
    # prediction rows -- TODO confirm.
    for i, window in enumerate(generator):
        y = 255 if y_pred_im[i] == 1 else 0
        result_mask[window.x_range, window.y_range, 0] = y

    _, img, bands = load_from_file(image_full_path, WORLDVIEW3)
    img = normalize_image(img, bands)
    rgb_img = get_rgb_bands(img, bands)

    plt.figure()
    plt.axis('off')
    plt.imshow(rgb_img)
    print("Counts:", np.unique(result_mask, return_counts=True))
    binary_mask = result_mask
    # Mask out background so only predicted windows are drawn over the image.
    show_mask = np.ma.masked_where(binary_mask == 0, binary_mask)
    plt.imshow(show_mask[:, :, 0], cmap='jet', interpolation='none', alpha=1.0)
    plt.savefig("{}/classification_results_{}_{}.png".format(results_path, image_name, current_time))
    plt.show()

    plt.figure()
    plt.axis('off')
    plt.imshow(binary_mask[:, :, 0], cmap='jet', interpolation='none', alpha=1.0)
    plt.savefig("{}/classification_mask_results_{}_{}.png".format(results_path, image_name, current_time))
    plt.show()

    print('Min {} Max {}'.format(binary_mask.min(), binary_mask.max()))
    print('Len > 0: {}'.format(len(binary_mask[binary_mask > 0])))
    print('Len == 0: {}'.format(len(binary_mask[binary_mask == 0])))
    # NOTE(review): truth_mask is 0/1 while binary_mask is 0/255 -- confirm that
    # jaccard_index_binary_masks binarizes its inputs.
    jaccard_index = jaccard_index_binary_masks(truth_mask[:, :, 0], binary_mask[:, :, 0])
    print("Jaccard index: {}".format(jaccard_index))
    return jaccard_index
# Build per-image SIFT/texton features and run every generated feature-set/classifier test.
pantex = Pantex(pantex_window_sizes)
lacunarity = Lacunarity(windows=((100, 100), (200, 200), (300, 300)))
feature_set.add(lacunarity, "LACUNARITY")
feature_set.add(pantex, "PANTEX")
best_score = 0
for im_num, image_name in enumerate(images):
    image_file = "{base_path}/{image_name}.{extension}".format(
        base_path=base_path, image_name=image_name, extension=extension)
    bands = WORLDVIEW3
    mask_full_path = "{base_path}/{image_name}_masked.tif".format(
        base_path=base_path, image_name=image_name)
    sat_image = SatelliteImage.load_from_file(image_file, bands)
    sat_image_shape = sat_image.shape

    sift = create_sift_feature(sat_image, ((25, 25), (50, 50), (100, 100)), image_name,
                               n_clusters=n_clusters, cached=True)
    texton = create_texton_feature(sat_image, ((25, 25), (50, 50), (100, 100)), image_name,
                                   n_clusters=n_clusters, cached=True)
    feature_set.add(sift, "SIFT")
    feature_set.add(texton, "TEXTON")

    # Use a distinct loop variable. Reusing `feature_set` here used to rebind the
    # outer FeatureSet, so the next image's .add() calls went into a test set.
    for test_feature_set, classifier in generate_tests((texton, sift, pantex, lacunarity)):
        plt.close('all')
        print("Running feature set {}, image {}, classifier {}".format(
            test_feature_set, image_name, str(classifier)))
        # Disabled training steps, kept for reference:
        # X = get_x_matrix(sat_image, image_name=image_name, feature_set=test_feature_set,
        #                  window_size=main_window_size, cached=True)
        # y, real_mask = get_y_vector(mask_full_path, main_window_size, percentage_threshold, cached=False)
        # X, y = balance_dataset(X, y, class_ratio=class_ratio)
        # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42, stratify=None)
# Open the mask at `out_file` and walk it window by window.
dataset = gdal.Open(out_file, gdal.GA_ReadOnly)
array = dataset.ReadAsArray()
print(array.shape)
# Per-pixel minimum over axis 0 (presumably bands -- TODO confirm), then add a
# trailing axis of length 1.
array = np.min(array, 0)
array = array[:, :, np.newaxis]
truth_mask = np.where(array > 0, 1, 0)

binary_sat_image = SatelliteImage(dataset, array, MASK_BANDS)
generator = CellGenerator(image=binary_sat_image, size=smallest_window_size)
result_mask = np.zeros(array.shape, dtype=np.uint8)
# `shape` is called as a method everywhere else CellGenerator is used in this
# file; passing the bound method itself to np.zeros would fail.
y_matrix = np.zeros(generator.shape())
for window in generator:
    y = 0
    unique, counts = np.unique(window.raw, return_counts=True)
def plot_overlap(y_pred, image_name, image_full_path, mask_full_path, main_window_size,
                 current_time, results_path):
    """Plot a predicted mask over its image and score it against the ground truth.

    Prediction ``y_pred[i]`` is painted into the i-th window produced by
    ``CellGenerator`` (255 for 1, 0 otherwise). Windows beyond ``len(y_pred)``
    stay 0. Two PNGs are saved to *results_path*: the mask over the RGB image,
    and the mask alone.

    :param y_pred: predicted labels for this image, one per window.
    :param image_name: name used in the output file names.
    :param image_full_path: satellite image drawn underneath the mask.
    :param mask_full_path: ground-truth mask file, opened with GDAL.
    :param main_window_size: window size handed to ``CellGenerator``.
    :param current_time: timestamp used in the output file names.
    :param results_path: directory the PNGs are written to.
    :return: value of ``jaccard_index_binary_masks`` for truth vs. prediction.
    """
    dataset = gdal.Open(mask_full_path, gdal.GA_ReadOnly)
    array = dataset.ReadAsArray()
    array = np.min(array, 0)
    array = array[:, :, np.newaxis]
    truth_mask = np.where(array > 0, 1, 0)

    binary_sat_image = SatelliteImage(dataset, array, MASK_BANDS)
    generator = CellGenerator(image=binary_sat_image, size=main_window_size)
    result_mask = np.zeros(array.shape, dtype=np.uint8)
    y_matrix = np.zeros(generator.shape())

    print("unique y_pred", np.unique(y_pred, return_counts=True))
    print(y_pred.shape)
    print(y_pred)
    print("Gen shape", generator.x_length, generator.y_length, generator.x_length * generator.y_length)
    print("result mask shape", result_mask.shape)
    print("{} == {}".format(generator.x_length * generator.y_length, y_pred.shape))

    n_pred = y_pred.shape[0]
    # Expected pixel count per positive window. Previously hard-coded as 30 * 30;
    # edge windows may be smaller, so treat this as an upper-bound diagnostic.
    window_area = int(np.prod(main_window_size))
    n_windows = 0
    y_expected = 0
    for window in generator:
        y = 255 if n_windows < n_pred and y_pred[n_windows] == 1 else 0
        y_matrix[window.x, window.y] = y
        result_mask[window.x_range, window.y_range, 0] = y
        n_windows += 1
        if y > 0:
            y_expected += window_area

    print("{} == {}".format(y_expected, len(result_mask[result_mask > 0])))
    print("Total iterations", n_windows)
    print("Y_matrix counts", np.unique(y_matrix, return_counts=True))
    print("Counts:", np.unique(result_mask, return_counts=True))
    print("result_mask[255s]", len(result_mask[result_mask == 255]))
    print("result_mask[0s]", len(result_mask[result_mask == 0]))

    _, img, bands = load_from_file(image_full_path, WORLDVIEW3)
    img = normalize_image(img, bands)
    rgb_img = get_rgb_bands(img, bands)

    plt.figure()
    plt.axis('off')
    plt.imshow(rgb_img)
    binary_mask = result_mask
    # Mask out background so only predicted windows are drawn over the image.
    show_mask = np.ma.masked_where(binary_mask == 0, binary_mask)
    plt.imshow(show_mask[:, :, 0], cmap='jet', interpolation='none', alpha=1.0)
    plt.savefig("{}/classification_jaccard_results_{}_{}.png".format(
        results_path, image_name, current_time))
    plt.show()

    plt.figure()
    plt.axis('off')
    plt.imshow(binary_mask[:, :, 0], cmap='jet', interpolation='none', alpha=1.0)
    plt.savefig("{}/classification_jaccard_mask_results_{}_{}.png".format(
        results_path, image_name, current_time))
    plt.show()

    print('Min {} Max {}'.format(binary_mask.min(), binary_mask.max()))
    print('Len > 0: {}'.format(len(binary_mask[binary_mask > 0])))
    print('Len == 0: {}'.format(len(binary_mask[binary_mask == 0])))
    # NOTE(review): truth_mask is 0/1 while binary_mask is 0/255 -- confirm that
    # jaccard_index_binary_masks binarizes its inputs.
    jaccard_index = jaccard_index_binary_masks(truth_mask[:, :, 0], binary_mask[:, :, 0])
    print("Jaccard index: {}".format(jaccard_index))
    return jaccard_index