def save(self, name, prediction, prefix="", verbose=False):
    """Write a prediction mask into ``self.save_dir`` as an 8-bit image.

    The prediction is multiplied by 255 and cast to ``uint8``. Presumably
    the values are probabilities in [0, 1]; TODO confirm against callers.

    * 2-D predictions are written with OpenCV (``cv2.imwrite``).
    * Multi-channel predictions are written channels-first with
      ``skimage.io.imsave``. When ``self.save_im_gdal_format`` is truthy,
      they are also written via ``CreateMultiBandGeoTiff`` into the
      directory ``self.save_dir + '_gdal'``, which is created if missing.

    Whether to write the GDAL copy is controlled by the instance attribute
    ``self.save_im_gdal_format``, not by a parameter.

    Args:
        name: Output file name, joined onto ``self.save_dir``.
        prediction: Array that is either 2-D, or 3-D with channels on the
            first axis. If the first axis is longer than 20, it is assumed
            to be spatial, i.e. the array is (h, w, channels).
        prefix: String prepended to ``name`` in every output path.
        verbose: If true, print shapes, values and paths for debugging.
    """
    save_im_gdal_format = self.save_im_gdal_format
    out_name = prefix + name
    outfile_sk = os.path.join(self.save_dir, out_name)

    if verbose:
        print("concrete_eval.py: prediction.shape:", prediction.shape)
        print("np.unique prediction:", np.unique(prediction))

    if len(prediction.shape) == 2:
        cv2.imwrite(outfile_sk, (prediction * 255).astype(np.uint8))
        return

    # Multi-channel data is stored as (channels, h, w). Fewer than 20
    # channels is assumed, so a longer first axis means the array is
    # (h, w, channels) and the channel axis is moved to the front.
    if prediction.shape[0] > 20:
        mask = np.moveaxis(prediction, -1, 0)
    else:
        mask = prediction
    # Convert once; the same uint8 array feeds every writer below.
    mask_uint8 = (mask * 255).astype(np.uint8)

    if verbose:
        print("concrete_eval.py: mask.shape:", mask.shape)
        print("name:", name)
        print("mask.shape:", mask.shape)
        print("mask.max:", mask.max())
        print("prediction.shape:", prediction.shape)
        print("outfile_sk:", outfile_sk)
        print(mask_uint8.shape)

    # Previously the file was immediately re-read here and the result
    # discarded; that wasted read has been removed.
    skimage.io.imsave(outfile_sk, mask_uint8, compress=1)

    if save_im_gdal_format:
        save_dir_gdal = self.save_dir + '_gdal'
        os.makedirs(save_dir_gdal, exist_ok=True)
        CreateMultiBandGeoTiff(os.path.join(save_dir_gdal, out_name),
                               mask_uint8)
def _channels_first(arr):
    """Return a 3-D (h, w, channels) array with its channel axis moved first.

    Fewer than 20 channels is assumed, so a 3-D array whose first axis is
    longer than 20 is treated as (h, w, channels). 2-D arrays and arrays
    that are already channels-first are returned unchanged.
    """
    if arr.ndim == 3 and arr.shape[0] > 20:
        return np.moveaxis(arr, -1, 0)
    return arr


def _read_prob_array(prob_path, num_classes, verbose=False):
    """Read one probability image in the layout used for fold averaging.

    * ``num_classes == 1``: grayscale (h, w) via OpenCV.
    * ``num_classes == 3``: (h, w, 3) via OpenCV, in OpenCV's channel order.
    * Any other value: read via skimage and moved to (channels, h, w) when
      it looks channels-last.

    Raises:
        IOError: if OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    if num_classes == 1:
        prob_arr = cv2.imread(prob_path, cv2.IMREAD_GRAYSCALE)
    elif num_classes == 3:
        prob_arr = cv2.imread(prob_path, 1)
    else:
        prob_arr = _channels_first(skimage.io.imread(prob_path))
    if prob_arr is None:
        raise IOError("Could not read image: {}".format(prob_path))
    if verbose:
        print("prob_path:", prob_path)
        print("prob_arr.shape:", prob_arr.shape)
    return prob_arr


def merge_tiffs(root, out_dir, out_dir_gdal=None, num_classes=1,
                verbose=False, num_folds=4):
    """Average per-fold probability TIFFs in ``root`` and write the results.

    Files named ``fold<k>_<name>`` (single-digit ``k``) are grouped by
    ``<name>``. The fold files ``k = 0 .. num_folds-1`` that exist are
    averaged pixel-wise, and the average is truncated to ``uint8``. If
    ``root`` contains no ``fold*`` files, each ``.tif``/``.tiff`` is
    processed on its own, so the output equals the input.

    Args:
        root: Directory holding the per-fold ``.tif``/``.tiff`` files.
        out_dir: Directory receiving the merged images. These are written
            with OpenCV for 1 or 3 classes, otherwise with skimage.
        out_dir_gdal: If truthy, the directory (created if missing) that
            also receives a channels-first multi-band GeoTIFF of each merged
            image.
        num_classes: Selects the reader and writer (1, 3, or other).
        verbose: If true, print per-file paths, shapes and dtypes.
        num_folds: Number of fold prefixes (``fold0_`` ...) to look for.
    """
    prob_files = {
        f for f in os.listdir(root)
        if os.path.splitext(f)[1] in ['.tif', '.tiff']
    }
    print("prob_files:", prob_files)
    # Strip the 6-character 'fold<k>_' prefix to get the shared base name.
    unfolded = {f[6:] for f in prob_files if f.startswith('fold')}
    folded = bool(unfolded)
    if not folded:
        unfolded = prob_files

    if out_dir_gdal:
        os.makedirs(out_dir_gdal, exist_ok=True)

    for prob_file in tqdm(sorted(unfolded)):
        if folded:
            sources = [
                'fold{}_'.format(fold) + prob_file
                for fold in range(num_folds)
                if 'fold{}_'.format(fold) + prob_file in prob_files
            ]
        else:
            # Un-prefixed inputs: read the file itself rather than a
            # nonexistent 'fold<k>_' variant.
            sources = [prob_file]
        if not sources:
            print("merge_tiffs: no fold files found for", prob_file,
                  "- skipping")
            continue
        if folded and len(sources) < num_folds:
            print("merge_tiffs: averaging only", len(sources), "of",
                  num_folds, "folds for", prob_file)

        probs = [_read_prob_array(os.path.join(root, src), num_classes,
                                  verbose)
                 for src in sources]
        prob_arr_mean = np.mean(probs, axis=0).astype(np.uint8)
        if verbose:
            print("prob_arr_mean.shape:", prob_arr_mean.shape)
            print("prob_arr_mean.dtype:", prob_arr_mean.dtype)

        res_path_geo = os.path.join(out_dir, prob_file)
        if num_classes == 1 or num_classes == 3:
            cv2.imwrite(res_path_geo, prob_arr_mean)
        else:
            # skimage is slow, but handles arbitrary channel counts.
            skimage.io.imsave(res_path_geo, prob_arr_mean, compress=1)

        if out_dir_gdal:
            # Write to out_dir_gdal. Previously this went to out_dir and
            # overwrote the file written just above.
            # The GeoTIFF writer gets a 3-D channels-first array, so a 2-D
            # mask gains a leading band axis instead of being transposed.
            mask_gdal = _channels_first(prob_arr_mean)
            if mask_gdal.ndim == 2:
                mask_gdal = mask_gdal[np.newaxis]
            CreateMultiBandGeoTiff(os.path.join(out_dir_gdal, prob_file),
                                   mask_gdal)