def main():
    """Convert SpaceNet road tiles to 8-bit images and, for training, build road masks.

    Command-line arguments:
        datasets: one or more raw dataset directories. Each is expected to
            contain an ``RGB-PanSharpen`` folder of ``.tif`` tiles, plus ``MUL``
            / ``PAN`` folders when ``--mul`` / ``--pan`` are given, and
            ``geojson/spacenetroads`` label files when ``--training`` is given.
        --training: also build a road-buffer mask per tile. Without it the run
            is a test run and only the 8-bit images are written.
        --rgb, --mul, --pan: additionally export the RGB, MUL or PAN version of
            each tile as an 8-bit PNG (percentiles 0-100) under
            ``/wdata/output_png``.

    Every output directory used here is deleted and recreated at the start of
    each run, so results from earlier runs are never reused.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('datasets', type=str, nargs='+')
    parser.add_argument('--training', action='store_true')
    parser.add_argument('--rgb', action='store_true')
    parser.add_argument('--mul', action='store_true')
    parser.add_argument('--pan', action='store_true')
    args = parser.parse_args()
    buffer_meters = 2
    burnValue = 255

    path_apls = r'/wdata'
    path_png = os.path.join(path_apls, 'output_png')
    test = not args.training
    save_rgb = args.rgb
    save_mul = args.mul
    save_pan = args.pan
    split_name = 'test' if test else 'train'
    path_outputs = os.path.join(path_apls, split_name, 'masks{}m'.format(buffer_meters))
    path_images_8bit = os.path.join(path_apls, split_name, 'images')
    for d in [path_outputs, path_images_8bit, path_png]:
        shutil.rmtree(d, ignore_errors=True)
        os.makedirs(d, exist_ok=True)

    raw_prefix = 'RGB-PanSharpen'
    for path_data in args.datasets:
        path_data = path_data.rstrip('/')
        # First three '_'-separated fields of the directory name plus a trailing
        # '_', e.g. "AOI_2_Vegas_Roads_Train" -> "AOI_2_Vegas_".
        test_data_name = os.path.split(path_data)[-1]
        test_data_name = '_'.join(test_data_name.split('_')[:3]) + '_'
        # Second field (e.g. "2" for "AOI_2_Vegas_") keys the `rescale` table,
        # which is defined outside this function. It is constant per dataset.
        rescale_type = test_data_name.split('_')[1]
        path_images_raw = os.path.join(path_data, raw_prefix)
        path_images_mul = os.path.join(path_data, 'MUL')
        path_images_pan = os.path.join(path_data, 'PAN')
        path_labels = os.path.join(path_data, 'geojson/spacenetroads')

        # iterate through images, convert to 8-bit, and create masks
        for im_file in os.listdir(path_images_raw):
            if not im_file.endswith('.tif'):
                continue

            # Text after the last '_' without extension,
            # e.g. "img1" for "RGB-PanSharpen_AOI_2_Vegas_img1.tif".
            name_root = im_file.split('_')[-1].split('.')[0]
            out_stem = test_data_name + name_root

            im_file_raw = os.path.join(path_images_raw, im_file)
            # The MUL/PAN exports must read the raw MUL/PAN tiles, not the PNG
            # output path. Assumes the SpaceNet layout, where the MUL/PAN tile
            # name differs from the RGB tile name only in the sensor prefix
            # (MUL_..., PAN_...) — TODO confirm against the dataset.
            im_file_mul_raw = os.path.join(path_images_mul, im_file.replace(raw_prefix, 'MUL', 1))
            im_file_pan_raw = os.path.join(path_images_pan, im_file.replace(raw_prefix, 'PAN', 1))

            im_file_out = os.path.join(path_images_8bit, out_stem + '.tif')
            im_file_rgb = os.path.join(path_png, out_stem + "_rgb.png")
            im_file_mul = os.path.join(path_png, out_stem + "_mul.png")
            im_file_pan = os.path.join(path_png, out_stem + "_pan.png")

            # 8-bit GeoTIFF used for training/inference (2-98 percentile stretch).
            if not os.path.isfile(im_file_out):
                apls_tools.convert_to_8Bit(im_file_raw, im_file_out,
                                           outputPixType='Byte',
                                           outputFormat='GTiff',
                                           rescale_type=rescale[rescale_type],
                                           percentiles=[2, 98])

            # Optional full-range PNG exports.
            if save_rgb and not os.path.isfile(im_file_rgb):
                apls_tools.convert_to_8Bit(im_file_raw, im_file_rgb,
                                           outputPixType='Byte',
                                           outputFormat='png',
                                           percentiles=[0, 100])

            if save_mul and not os.path.isfile(im_file_mul):
                apls_tools.convert_to_8Bit(im_file_mul_raw, im_file_mul,
                                           outputPixType='Byte',
                                           outputFormat='png',
                                           percentiles=[0, 100])

            if save_pan and not os.path.isfile(im_file_pan):
                apls_tools.convert_to_8Bit(im_file_pan_raw, im_file_pan,
                                           outputPixType='Byte',
                                           outputFormat='png',
                                           percentiles=[0, 100])

            if test:
                continue

            # Join path_labels exactly once. Joining it again would double the
            # prefix whenever the dataset path is relative.
            label_file_tot = os.path.join(path_labels, 'spacenetroads_' + out_stem + '.geojson')
            output_raster = os.path.join(path_outputs, out_stem + '.png')

            print("\nname_root:", name_root)
            print("  output_raster:", output_raster)

            # Presumably buffers the labelled roads by buffer_meters and burns
            # them (value burnValue) into output_raster; the return values are
            # not needed here.
            apls_tools.get_road_buffer(label_file_tot, im_file_out,
                                       output_raster,
                                       buffer_meters=buffer_meters,
                                       burnValue=burnValue,
                                       bufferRoundness=6,
                                       plot_file=None,
                                       figsize=(6, 6),
                                       fontsize=8,
                                       dpi=200, show_plot=False,
                                       verbose=False)
# ---- Example no. 2 ----
# 0
def main():
    """Prepare 8-bit imagery (and, for training, road masks) from a JSON config.

    Command-line arguments:
        config_path: path to a JSON file whose keys are passed to ``Config``.
        --training: build the training set (8-bit images + masks); otherwise
            build the test set (8-bit images only).

    Config fields read here: num_channels, mask_width_m, path_data_root,
    train_data_refined_dir / test_data_refined_dir, and
    data_train_raw_parts / data_test_raw_parts. The raw-parts fields hold
    comma-separated directory names relative to path_data_root.

    The output directories are deleted and recreated at the start of each run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config_path')
    parser.add_argument('--training', action='store_true')
    args = parser.parse_args()

    with open(args.config_path, 'r') as f:
        cfg = json.load(f)
    config = Config(**cfg)

    # Raw imagery folder and the matching key into the `rescale` table
    # (defined outside this function) both depend only on the band count.
    if config.num_channels == 3:
        image_format_path = 'RGB-PanSharpen'
        rescale_type = 'tot_3band'
    else:
        image_format_path = 'MUL-PanSharpen'
        rescale_type = 'tot_8band'

    imfile_prefix = image_format_path + '_'
    label_path_extra = 'geojson/spacenetroads'
    geojson_prefix = 'spacenetroads_'
    burnValue = 255

    buffer_meters = float(config.mask_width_m)
    # e.g. 2.0 -> "2p0", so the mask directory name contains no '.'
    buffer_meters_str = str(np.round(buffer_meters, 1)).replace('.', 'p')
    test = not args.training

    if test:
        refined_dir = config.test_data_refined_dir
        raw_parts = config.data_test_raw_parts
    else:
        refined_dir = config.train_data_refined_dir
        raw_parts = config.data_train_raw_parts

    path_masks = os.path.join(config.path_data_root, refined_dir,
                              'masks{}m'.format(buffer_meters_str))
    if test:
        # Test images go directly into the refined dir; no masks are written.
        path_images_8bit = os.path.join(config.path_data_root, refined_dir)
        dirs_to_reset = [path_images_8bit]
    else:
        # All training images share one directory so training can find them.
        path_images_8bit = os.path.join(config.path_data_root, refined_dir, 'images')
        dirs_to_reset = [path_masks, path_images_8bit]

    # Clean each output directory exactly once.
    for d in dirs_to_reset:
        print("cleaning and remaking:", d)
        shutil.rmtree(d, ignore_errors=True)
        os.makedirs(d, exist_ok=True)

    # Strip every comma-separated entry *before* joining, so "a, b" resolves to
    # <root>/b rather than "<root>/ b". Skip blanks left by stray commas.
    paths_data_raw = [os.path.join(config.path_data_root, part.strip())
                      for part in raw_parts.split(',')
                      if part.strip()]

    for path_data in paths_data_raw:
        path_data = path_data.strip().rstrip('/')
        path_images_raw = os.path.join(path_data, image_format_path)
        path_labels = os.path.join(path_data, label_path_extra)

        # iterate through images, convert to 8-bit, and create masks
        for im_file in os.listdir(path_images_raw):
            if not im_file.endswith('.tif'):
                continue

            # File name without the "<format>_" prefix and the extension.
            name_root_full = im_file.split(imfile_prefix)[-1].split('.')[0]

            im_file_raw = os.path.join(path_images_raw, im_file)
            im_file_out = os.path.join(path_images_8bit, im_file)

            if not os.path.isfile(im_file_out):
                apls_tools.convert_to_8Bit(im_file_raw,
                                           im_file_out,
                                           outputPixType='Byte',
                                           outputFormat='GTiff',
                                           rescale_type=rescale[rescale_type])

            if test:
                continue

            # determine mask output files
            label_name = geojson_prefix + name_root_full + '.geojson'
            label_file_tot = os.path.join(path_labels, label_name)
            output_raster = os.path.join(path_masks, im_file)
            print("\nname_root:", name_root_full)
            print("  output_mask_raster:", output_raster)

            # Presumably buffers the labelled roads by buffer_meters and burns
            # them (value burnValue) into output_raster; the return values are
            # not needed here.
            apls_tools.get_road_buffer(
                label_file_tot,
                im_file_out,
                output_raster,
                buffer_meters=buffer_meters,
                burnValue=burnValue,
                bufferRoundness=6,
                plot_file=None,
                figsize=(6, 6),
                fontsize=8,
                dpi=200,
                show_plot=False,
                verbose=False)
def main():
    """Convert SpaceNet road tiles to 8-bit images and, for training, build road masks.

    Command-line arguments:
        datasets: one or more raw dataset directories. Each is expected to
            contain an ``RGB-PanSharpen`` folder of ``.tif`` tiles and, when
            ``--training`` is given, ``geojson/spacenetroads`` label files.
        --training: also build a road-buffer mask per tile. Without it the run
            is a test run and only the 8-bit images are written.

    Output directories under ``/wdata/{train,test}/`` are deleted and
    recreated at the start of each run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('datasets', type=str, nargs='+')
    parser.add_argument('--training', action='store_true')
    args = parser.parse_args()
    buffer_meters = 2
    burnValue = 255

    path_apls = r'/wdata'
    test = not args.training
    split_name = 'test' if test else 'train'
    path_outputs = os.path.join(path_apls, split_name, 'masks{}m'.format(buffer_meters))
    path_images_8bit = os.path.join(path_apls, split_name, 'images')
    for d in [path_outputs, path_images_8bit]:
        shutil.rmtree(d, ignore_errors=True)
        os.makedirs(d, exist_ok=True)

    for path_data in args.datasets:
        path_data = path_data.rstrip('/')
        # First three '_'-separated fields of the directory name plus a trailing
        # '_', e.g. "AOI_2_Vegas_Roads_Train" -> "AOI_2_Vegas_".
        test_data_name = os.path.split(path_data)[-1]
        test_data_name = '_'.join(test_data_name.split('_')[:3]) + '_'
        # Second field (e.g. "2" for "AOI_2_Vegas_") keys the `rescale` table,
        # which is defined outside this function. It is constant per dataset.
        rescale_type = test_data_name.split('_')[1]
        path_images_raw = os.path.join(path_data, 'RGB-PanSharpen')
        path_labels = os.path.join(path_data, 'geojson/spacenetroads')

        # iterate through images, convert to 8-bit, and create masks
        for im_file in os.listdir(path_images_raw):
            if not im_file.endswith('.tif'):
                continue

            # Text after the last '_' without extension,
            # e.g. "img1" for "RGB-PanSharpen_AOI_2_Vegas_img1.tif".
            name_root = im_file.split('_')[-1].split('.')[0]
            out_stem = test_data_name + name_root

            im_file_raw = os.path.join(path_images_raw, im_file)
            im_file_out = os.path.join(path_images_8bit, out_stem + '.tif')

            # 8-bit GeoTIFF (2-98 percentile stretch).
            if not os.path.isfile(im_file_out):
                apls_tools.convert_to_8Bit(im_file_raw, im_file_out,
                                           outputPixType='Byte',
                                           outputFormat='GTiff',
                                           rescale_type=rescale[rescale_type],
                                           percentiles=[2, 98])

            if test:
                continue

            # Join path_labels exactly once. Joining it again would double the
            # prefix whenever the dataset path is relative.
            label_file_tot = os.path.join(path_labels, 'spacenetroads_' + out_stem + '.geojson')
            output_raster = os.path.join(path_outputs, out_stem + '.png')

            print("\nname_root:", name_root)
            print("  output_raster:", output_raster)

            # Presumably buffers the labelled roads by buffer_meters and burns
            # them (value burnValue) into output_raster; the return values are
            # not needed here.
            apls_tools.get_road_buffer(label_file_tot, im_file_out,
                                       output_raster,
                                       buffer_meters=buffer_meters,
                                       burnValue=burnValue,
                                       bufferRoundness=6,
                                       plot_file=None,
                                       figsize=(6, 6),
                                       fontsize=8,
                                       dpi=200, show_plot=False,
                                       verbose=False)