示例#1
0
    def save(
            self,
            name,
            prediction,
            prefix="",
            verbose=False):
        """Write a predicted mask under ``self.save_dir`` (optionally also via GDAL).

        A 2-D ``prediction`` is written with ``cv2.imwrite``. Any other
        prediction is treated as multi-channel: it is rearranged to
        channels-first when needed and written with ``skimage.io.imsave``.
        When ``self.save_im_gdal_format`` is truthy, the same array is also
        written with ``CreateMultiBandGeoTiff`` into the directory
        ``self.save_dir + '_gdal'``, which is created if missing.

        Values are multiplied by 255 and cast to ``uint8`` before writing.
        Presumably ``prediction`` holds probabilities in [0, 1] (confirm with
        callers); values outside that range are not clipped before the cast.

        Args:
            name: Output file name, joined onto the output directory.
            prediction: Array of shape (h, w), or a multi-channel array in
                either (h, w, channels) or (channels, h, w) layout.
            prefix: String prepended to ``name`` for every file written.
            verbose: If True, print shape/path diagnostics.
        """
        # GDAL export is configured on the instance, not per call.
        save_im_gdal_format = self.save_im_gdal_format
        if verbose:
            print("concrete_eval.py: prediction.shape:", prediction.shape)
            print("np.unique prediction:", np.unique(prediction))

        outfile = os.path.join(self.save_dir, prefix + name)
        if len(prediction.shape) == 2:
            cv2.imwrite(outfile, (prediction * 255).astype(np.uint8))
            return

        # Multi-channel: skimage expects (channels, h, w). This assumes fewer
        # than 20 channels, so a leading dimension above 20 is taken to be
        # the height of an (h, w, channels) array, and the channels are moved
        # to the front.
        if prediction.shape[0] > 20:
            mask = np.moveaxis(prediction, -1, 0)
        else:
            mask = prediction
        if verbose:
            print("concrete_eval.py: mask.shape:", mask.shape)

        # Scale once; both writers below consume the same uint8 array.
        mask_u8 = (mask * 255).astype(np.uint8)
        if verbose:
            print("name:", name)
            print("mask.shape:", mask.shape)
            print("mask.max:", mask.max())
            print("prediction.shape:", prediction.shape)
            print("outfile_sk:", outfile)
            print(mask_u8.shape)
        skimage.io.imsave(outfile, mask_u8, compress=1)

        if save_im_gdal_format:
            save_dir_gdal = self.save_dir + '_gdal'
            os.makedirs(save_dir_gdal, exist_ok=True)
            CreateMultiBandGeoTiff(
                os.path.join(save_dir_gdal, prefix + name), mask_u8)
def merge_tiffs(root,
                out_dir,
                out_dir_gdal=None,
                num_classes=1,
                verbose=False,
                num_folds=4):
    """Average per-fold probability TIFFs found in ``root`` and write the result.

    Files named ``foldN_<name>.tif``/``.tiff`` are grouped by ``<name>``. The
    6-character ``foldN_`` prefix is stripped, so fold indices are assumed to
    be a single digit. For each name, the images for folds
    ``0 .. num_folds - 1`` are averaged pixel-wise, truncated to ``uint8``,
    and written to ``out_dir/<name>``. If ``root`` contains no
    ``fold``-prefixed TIFFs, each TIFF is read on its own and re-written
    after the ``uint8`` cast.

    Args:
        root: Directory containing the per-fold TIFFs.
        out_dir: Directory for the merged images; created if missing.
        out_dir_gdal: If given (and ``num_classes`` is not 1 or 3), also
            write each merged image there with ``CreateMultiBandGeoTiff``,
            channels first. Created if missing.
        num_classes: 1 reads grayscale and 3 reads colour, both via OpenCV.
            Any other value reads multi-channel images via skimage.
        verbose: If True, print per-file paths, shapes and dtypes.
        num_folds: Number of folds to average (default 4).

    Raises:
        FileNotFoundError: if an expected input image cannot be read.
    """
    prob_files = {
        f
        for f in os.listdir(root)
        if os.path.splitext(f)[1] in ['.tif', '.tiff']
    }
    print("prob_files:", prob_files)
    unfolded = {f[6:] for f in prob_files if f.startswith('fold')}
    # Without fold-prefixed files, merge each file on its own instead of
    # looking for nonexistent "foldN_" variants.
    use_folds = bool(unfolded)
    if not use_folds:
        unfolded = prob_files

    os.makedirs(out_dir, exist_ok=True)
    if out_dir_gdal:
        os.makedirs(out_dir_gdal, exist_ok=True)

    for prob_file in tqdm(sorted(unfolded)):
        if use_folds:
            prob_paths = [
                os.path.join(root, 'fold{}_'.format(fold) + prob_file)
                for fold in range(num_folds)
            ]
        else:
            prob_paths = [os.path.join(root, prob_file)]

        probs = []
        for prob_path in prob_paths:
            prob_arr = _read_prob(prob_path, num_classes)
            if verbose:
                print("prob_path:", prob_path)
                print("prob_arr.shape:", prob_arr.shape)
            probs.append(prob_arr)

        prob_arr_mean = np.mean(probs, axis=0).astype(np.uint8)
        if verbose:
            print("prob_arr_mean.shape:", prob_arr_mean.shape)
            print("prob_arr_mean.dtype:", prob_arr_mean.dtype)

        res_path_geo = os.path.join(out_dir, prob_file)
        if num_classes == 1 or num_classes == 3:
            cv2.imwrite(res_path_geo, prob_arr_mean)
            continue

        skimage.io.imsave(res_path_geo, prob_arr_mean, compress=1)

        if out_dir_gdal:
            # Must target out_dir_gdal: writing into out_dir would overwrite
            # the file just saved above.
            outpath_gdal = os.path.join(out_dir_gdal, prob_file)
            # GDAL writer wants channels first; assume fewer than 20 channels.
            if prob_arr_mean.shape[0] > 20:
                mask_gdal = np.moveaxis(prob_arr_mean, -1, 0)
            else:
                mask_gdal = prob_arr_mean
            CreateMultiBandGeoTiff(outpath_gdal, mask_gdal)


def _read_prob(prob_path, num_classes):
    """Read one probability image; multi-channel results are (channels, h, w).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``prob_path``
            (``cv2.imread`` returns None rather than raising).
    """
    if num_classes == 1:
        prob_arr = cv2.imread(prob_path, cv2.IMREAD_GRAYSCALE)
    elif num_classes == 3:
        prob_arr = cv2.imread(prob_path, 1)
    else:
        prob_arr = skimage.io.imread(prob_path)
        # Want (channels, h, w). This assumes fewer than 20 channels, so a
        # leading dimension above 20 means the image came back as
        # (h, w, channels). Move the LAST axis to the front; moving axis 0 to
        # the end would produce (w, channels, h).
        if prob_arr.shape[0] > 20:
            prob_arr = np.moveaxis(prob_arr, -1, 0)
    if prob_arr is None:
        raise FileNotFoundError(
            "could not read probability image: {}".format(prob_path))
    return prob_arr