Beispiel #1
0
def main(params: dict):
    """
    Identify the class to which each image belongs.

    For a "classification" task, inference is delegated to ``classifier``.
    For a "segmentation" task, every listed image is segmented with the loaded
    model and the prediction is written as a single-band, LZW-compressed uint8
    GeoTIFF next to the other inference outputs. When a GeoPackage is listed
    with an image, the image is clipped to it, the label is rasterized, and
    pixel metrics are computed (benchmark mode). Optionally each prediction is
    vectorized (``ras2vec``).

    Mlflow tracking is used only when ``global.mlflow_uri`` points to an
    existing directory; otherwise everything is logged to the console.

    :param params: (dict) Parameters found in the yaml config file. Sections read
        here: 'global', 'inference', 'training' and 'sample'.
    :raises ValueError: if ``task`` is neither "classification" nor
        "segmentation", if a target id is not an int, or if the benchmark
        outputs are inconsistent.
    """
    # Function-scope import (like ``logging.config`` below) so the module
    # header does not need to change.
    from contextlib import nullcontext

    since = time.time()

    # MANDATORY PARAMETERS
    img_dir_or_csv = get_key_def('img_dir_or_csv_file',
                                 params['inference'],
                                 expected_type=str)
    state_dict = get_key_def('state_dict_path', params['inference'])
    task = get_key_def('task', params['global'], expected_type=str)
    if task not in ['classification', 'segmentation']:
        raise ValueError(
            f'Task should be either "classification" or "segmentation". Got {task}'
        )
    model_name = get_key_def('model_name', params['global'],
                             expected_type=str).lower()
    num_classes = get_key_def('num_classes',
                              params['global'],
                              expected_type=int)
    num_bands = get_key_def('number_of_bands',
                            params['global'],
                            expected_type=int)
    chunk_size = get_key_def('chunk_size',
                             params['inference'],
                             default=512,
                             expected_type=int)
    BGR_to_RGB = get_key_def('BGR_to_RGB',
                             params['global'],
                             expected_type=bool)

    # OPTIONAL PARAMETERS
    # ignore_index is read once and reused for the model and for label validation.
    dontcare_val = get_key_def("ignore_index",
                               params["training"],
                               default=-1,
                               expected_type=int)
    num_devices = get_key_def('num_gpus',
                              params['global'],
                              default=0,
                              expected_type=int)
    default_max_used_ram = 25
    max_used_ram = get_key_def('max_used_ram',
                               params['global'],
                               default=default_max_used_ram,
                               expected_type=int)
    max_used_perc = get_key_def('max_used_perc',
                                params['global'],
                                default=25,
                                expected_type=int)
    scale = get_key_def('scale_data',
                        params['global'],
                        default=[0, 1],
                        expected_type=List)
    debug = get_key_def('debug_mode',
                        params['global'],
                        default=False,
                        expected_type=bool)
    raster_to_vec = get_key_def('ras2vec', params['inference'], False)

    # benchmark (ie when gpkgs are inputted along with imagery)
    targ_ids = get_key_def('target_ids',
                           params['sample'],
                           None,
                           expected_type=List)

    # SETTING OUTPUT DIRECTORY
    working_folder = Path(
        params['inference']['state_dict_path']).parent.joinpath(
            f'inference_{num_bands}bands')
    Path.mkdir(working_folder, parents=True, exist_ok=True)

    # mlflow logging
    mlflow_uri = get_key_def('mlflow_uri',
                             params['global'],
                             default=None,
                             expected_type=str)
    if mlflow_uri and not Path(mlflow_uri).is_dir():
        warnings.warn(f'Mlflow uri path is not valid: {mlflow_uri}')
        mlflow_uri = None
    # SETUP LOGGING
    # NOTE: this import binds ``logging`` as a local name of this function, so
    # ``logging`` must not be used above this line.
    import logging.config  # See: https://docs.python.org/2.4/lib/logging-config-fileformat.html
    if mlflow_uri:
        log_config_path = Path('utils/logging.conf').absolute()
        logfile = f'{working_folder}/info.log'
        logfile_debug = f'{working_folder}/debug.log'
        console_level_logging = 'INFO' if not debug else 'DEBUG'
        logging.config.fileConfig(log_config_path,
                                  defaults={
                                      'logfilename': logfile,
                                      'logfilename_debug': logfile_debug,
                                      'console_level': console_level_logging
                                  })

        # import only if mlflow uri is set
        from mlflow import log_params, set_tracking_uri, set_experiment, start_run, log_artifact, log_metrics
        # mlflow_uri is an existing directory at this point (checked above).
        set_tracking_uri(mlflow_uri)
        exp_name = get_key_def('mlflow_experiment_name',
                               params['global'],
                               default='gdl-inference',
                               expected_type=str)
        set_experiment(f'{exp_name}/{working_folder.name}')
        run_name = get_key_def('mlflow_run_name',
                               params['global'],
                               default='gdl',
                               expected_type=str)
        start_run(run_name=run_name)
        log_params(params['global'])
        log_params(params['inference'])
    else:
        # set a console logger as default
        logging.basicConfig(level=logging.DEBUG)
        logging.info(
            'No logging folder set for mlflow. Logging will be limited to console'
        )

    if debug:
        logging.warning(
            f'Debug mode activated. Some debug features may mobilize extra disk space and '
            f'cause delays in execution.')

    # Assert that all items in target_ids are integers (ex.: to benchmark single-class model with multi-class labels)
    if targ_ids:
        for item in targ_ids:
            if not isinstance(item, int):
                raise ValueError(
                    f'Target id "{item}" in target_ids is {type(item)}, expected int.'
                )

    logging.info(f'Inferences will be saved to: {working_folder}\n\n')
    if not (0 <= max_used_ram <= 100):
        logging.warning(
            f'Max used ram parameter should be a percentage. Got {max_used_ram}. '
            f'Will set default value of {default_max_used_ram} %')
        max_used_ram = default_max_used_ram

    # AWS
    # NOTE(review): ``bucket`` is never assigned anything but None in this
    # function, so every ``if bucket:`` branch below is currently unreachable.
    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['global'])

    # list of GPU devices that are available and unused. If no GPUs, returns empty dict
    gpu_devices_dict = get_device_ids(num_devices,
                                      max_used_ram_perc=max_used_ram,
                                      max_used_perc=max_used_perc)
    if gpu_devices_dict:
        logging.info(
            f"Number of cuda devices requested: {num_devices}. Cuda devices available: {gpu_devices_dict}. "
            f"Using {list(gpu_devices_dict.keys())[0]}\n\n")
        # NOTE(review): the model is always placed on cuda:0, even though the
        # message above names the first key of gpu_devices_dict. Confirm
        # whether get_device_ids remaps visible devices before changing this.
        device = torch.device('cuda:0')
    else:
        logging.warning(
            f"No Cuda device available. This process will only run on CPU")
        device = torch.device('cpu')

    # CONFIGURE MODEL
    num_classes_backgr = add_background_to_num_class(task, num_classes)
    model, loaded_checkpoint, model_name = net(model_name=model_name,
                                               num_bands=num_bands,
                                               num_channels=num_classes_backgr,
                                               dontcare_val=dontcare_val,
                                               num_devices=1,
                                               net_params=params,
                                               inference_state_dict=state_dict)
    try:
        model.to(device)
    except RuntimeError:
        # Fall back to the default cuda device (or cpu) if cuda:0 is unusable.
        logging.info(f"Unable to use device 0")
        device = torch.device(f'cuda' if gpu_devices_dict else 'cpu')
        model.to(device)

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    list_img = list_input_images(img_dir_or_csv,
                                 bucket_name,
                                 glob_patterns=["*.tif", "*.TIF"])

    # VALIDATION: anticipate problems with imagery and label (if provided) before entering main for loop
    valid_gpkg_set = set()
    for info in tqdm(list_img, desc='Validating imagery'):
        # validate_raster(info['tif'], num_bands, meta_map)
        gpkg = info.get('gpkg')
        if gpkg and gpkg not in valid_gpkg_set:
            validate_num_classes(vector_file=gpkg,
                                 num_classes=num_classes,
                                 attribute_name=info['attribute_name'],
                                 ignore_index=dontcare_val,
                                 target_ids=targ_ids)
            assert_crs_match(info['tif'], gpkg)
            valid_gpkg_set.add(gpkg)

    logging.info('Successfully validated imagery')
    if valid_gpkg_set:
        logging.info('Successfully validated label data for benchmarking')

    if task == 'classification':
        classifier(
            params, list_img, model, device, working_folder
        )  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        gdf_ = []
        gpkg_name_ = []

        # TODO: Add verifications?
        if bucket:
            bucket.download_file(
                loaded_checkpoint,
                "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(loaded_checkpoint, model)
        # LOOP THROUGH LIST OF INPUT IMAGES
        for info in tqdm(list_img,
                         desc='Inferring from images',
                         position=0,
                         leave=True):
            img_name = Path(info['tif']).name
            # One nested mlflow run per image, but only when mlflow is enabled:
            # start_run/log_metrics are imported only in that case.
            run_context = (start_run(run_name=img_name, nested=True)
                           if mlflow_uri else nullcontext())
            with run_context:
                local_gpkg = Path(info['gpkg']) if info.get('gpkg') else None
                gpkg_name = local_gpkg.stem if local_gpkg else None
                if bucket:
                    local_img = f"Images/{img_name}"
                    bucket.download_file(info['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'],
                                                 info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]
                else:  # FIXME: else statement should support img['meta'] integration as well.
                    local_img = Path(info['tif'])
                    Path.mkdir(working_folder.joinpath(local_img.parent.name),
                               parents=True,
                               exist_ok=True)
                    inference_image = working_folder.joinpath(
                        local_img.parent.name,
                        f"{img_name.split('.')[0]}_inference.tif")
                # Path() so this also works when local_img is a str (bucket branch).
                temp_file = working_folder.joinpath(
                    Path(local_img).parent.name, f"{img_name.split('.')[0]}.dat")
                raster = rasterio.open(local_img, 'r')
                logging.info(f'Reading original image: {raster.name}')
                inf_meta = raster.meta
                label = None
                if local_gpkg:
                    logging.info(f'Burning label as raster: {local_gpkg}')
                    local_img = clip_raster_with_gpkg(raster, local_gpkg)
                    raster.close()
                    raster = rasterio.open(local_img, 'r')
                    logging.info(f'Reading clipped image: {raster.name}')
                    inf_meta = raster.meta
                    label = vector_to_raster(
                        vector_file=local_gpkg,
                        input_image=raster,
                        out_shape=(inf_meta['height'], inf_meta['width']),
                        attribute_name=info['attribute_name'],
                        fill=0,  # background value in rasterized vector.
                        target_ids=targ_ids)
                    if debug:
                        logging.debug(
                            f'Unique values in loaded label as raster: {np.unique(label)}\n'
                            f'Shape of label as raster: {label.shape}')
                pred, gdf = segmentation(param=params,
                                         input_image=raster,
                                         label_arr=label,
                                         num_classes=num_classes_backgr,
                                         gpkg_name=gpkg_name,
                                         model=model,
                                         chunk_size=chunk_size,
                                         device=device,
                                         scale=scale,
                                         BGR_to_RGB=BGR_to_RGB,
                                         tp_mem=temp_file,
                                         debug=debug)
                # The input raster is not used after inference; release the handle.
                raster.close()
                if gdf is not None:
                    gdf_.append(gdf)
                    gpkg_name_.append(gpkg_name)
                if local_gpkg:
                    pixel_metrics = ComputePixelMetrics(label, pred,
                                                        num_classes_backgr)
                    for metric_func in (pixel_metrics.iou, pixel_metrics.dice):
                        metric_values = pixel_metrics.update(metric_func)
                        if mlflow_uri:
                            log_metrics(metric_values)
                        else:
                            logging.info(f'Metrics for {img_name}: {metric_values}')
                # Add a leading band axis: rasterio writes (bands, rows, cols).
                pred = pred[np.newaxis, :, :].astype(np.uint8)
                inf_meta.update({
                    "driver": "GTiff",
                    "height": pred.shape[1],
                    "width": pred.shape[2],
                    "count": pred.shape[0],
                    "dtype": 'uint8',
                    "compress": 'lzw'
                })
                logging.info(
                    f'Successfully inferred on {img_name}\nWriting to file: {inference_image}'
                )
                with rasterio.open(inference_image, 'w+', **inf_meta) as dest:
                    dest.write(pred)
                del pred
                try:
                    temp_file.unlink()
                except OSError as e:
                    logging.warning(f'File Error: {temp_file}: {e.strerror}')
                if raster_to_vec:
                    start_vec = time.time()
                    inference_vec = working_folder.joinpath(
                        Path(local_img).parent.name,
                        f"{img_name.split('.')[0]}_inference.gpkg")
                    ras2vec(inference_image, inference_vec)
                    end_vec = time.time() - start_vec
                    logging.info(
                        'Vectorization completed in {:.0f}m {:.0f}s'.format(
                            end_vec // 60, end_vec % 60))

        if len(gdf_) >= 1:
            if not len(gdf_) == len(gpkg_name_):
                raise ValueError('benchmarking unable to complete')
            # Concatenate all geo data frames into one geo data frame
            all_gdf = pd.concat(gdf_)
            all_gdf.reset_index(drop=True, inplace=True)
            gdf_x = gpd.GeoDataFrame(all_gdf)
            bench_gpkg = working_folder / "benchmark.gpkg"
            gdf_x.to_file(bench_gpkg, driver="GPKG", index=False)
            logging.info(
                f'Successfully wrote benchmark geopackage to: {bench_gpkg}')
        # log_artifact(working_folder)
    time_elapsed = time.time() - since
    logging.info('Inference Script completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Beispiel #2
0
def main(params):
    """
    Training and validation datasets preparation.

    Process
    -------
    1. Read csv file and validate existence of all input files and GeoPackages.

    2. Do the following verifications:
        1. Assert number of bands found in raster is equal to desired number
           of bands.
        2. Check that `num_classes` is equal to number of classes detected in
           the specified attribute for each GeoPackage.
           Warning: this validation will not succeed if a Geopackage
                    contains only a subset of `num_classes` (e.g. 3 of 4).
        3. Assert Coordinate reference system between raster and gpkg match.

    3. Read csv file and for each line in the file, do the following:
        1. Read input image as array with utils.readers.image_reader_as_array().
            - If gpkg's extent is smaller than raster's extent,
              raster is clipped to gpkg's extent.
            - If gpkg's extent is bigger than raster's extent,
              gpkg is clipped to raster's extent.
        2. Convert GeoPackage vector information into the "label" raster with
           utils.utils.vector_to_raster(). The pixel value is determined by the
           attribute in the csv file.
        3. Create a new raster called "label" with the same properties as the
           input image.
        4. Read metadata and add to input as new bands (*more details to come*).
        5. Crop the arrays in smaller samples of the size `samples_size` of
           `your_conf.yaml`. Visual representation of this is provided at
            https://medium.com/the-downlinq/broad-area-satellite-imagery-semantic-segmentation-basiss-4a7ea2c8466f
        6. Write samples from input image and label into the "val", "trn" or
           "tst" hdf5 file, depending on the value contained in the csv file.
            Refer to samples_preparation().

    -------
    :param params: (dict) Parameters found in the yaml config file. Mutated in
        place: ``global.git_hash`` is set, and ``training.ignore_index`` is set to
        -1 when it was 0.
    :raises AssertionError: if the task is not "segmentation", or (debug mode)
        if a vector file contains invalid features.
    :raises ValueError: if a csv row's dataset is neither "trn" nor "tst".
    """
    params['global']['git_hash'] = get_git_hash()
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []

    assert params['global'][
        'task'] == 'segmentation', f"images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = get_key_def('bucket_name', params['global'])
    data_path = Path(params['global']['data_path'])
    Path.mkdir(data_path, exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = get_key_def('min_annotated_percent',
                                 params['sample']['sampling_method'],
                                 None,
                                 expected_type=int)
    num_bands = params['global']['number_of_bands']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activate. Execution may take longer...')

    final_samples_folder = None

    sample_path_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'

    # AWS
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        # Destination prefix in the bucket for the produced hdf5 files
        # (a pathlib.Path is always truthy, so no fallback is needed).
        final_samples_folder = data_path.joinpath("samples")
        # Local working folder; must be a Path for the is_dir()/mkdir calls below.
        samples_folder = Path(sample_path_name)
        # Images are downloaded into "Images/" below; boto3 does not create it.
        Path("Images").mkdir(exist_ok=True)
    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(sample_path_name)

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    Path.mkdir(samples_folder, exist_ok=False
               )  # TODO: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(f'\nSuccessfully read csv file: {Path(csv_file).stem}\n'
               f'Number of rows: {len(list_data_prep)}\n'
               f'Copying first entry:\n{list_data_prep[0]}\n')
    ignore_index = get_key_def('ignore_index', params['training'], -1)
    meta_map, metadata = get_key_def("meta_map", params["global"], {}), None

    # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg and (2) check CRS match (tif and gpkg)
    # Classes are accumulated over every GeoPackage, not just the last one.
    gpkg_classes = set()
    valid_gpkg_set = set()
    for info in tqdm(list_data_prep, position=0):
        assert_num_bands(info['tif'], num_bands, meta_map)
        if info['gpkg'] not in valid_gpkg_set:
            gpkg_classes.update(
                validate_num_classes(info['gpkg'],
                                     params['global']['num_classes'],
                                     info['attribute_name'], ignore_index))
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    if debug:
        # VALIDATION (debug only): Checking validity of features in vector files
        for info in tqdm(
                list_data_prep,
                position=0,
                desc="Checking validity of features in vector files"):
            invalid_features = validate_features_from_gpkg(
                info['gpkg'], info['attribute_name']
            )  # TODO: test this with invalid features.
            assert not invalid_features, f"{info['gpkg']}: Invalid geometry object(s) '{invalid_features}'"

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    class_prop = get_key_def('class_proportion',
                             params['sample']['sampling_method'],
                             None,
                             expected_type=dict)

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)
    # Make sure the hdf5 files are closed even if an unexpected error escapes.
    try:
        # Set dontcare (aka ignore_index) value (same config key as ignore_index above)
        dontcare = ignore_index  # TODO: deduplicate with train_segmentation, l300
        if dontcare == 0:
            warnings.warn(
                "The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
                " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
                " will be remapped to -1 while loading the dataset, and inside the config from now on."
            )
            params["training"]["ignore_index"] = -1
            # NOTE(review): the local ``dontcare`` stays 0, so nodata pixels are
            # burned with 0 (the background value) below. Confirm whether it
            # should also become -1 here.

        # creates pixel_classes dict and keys
        pixel_classes = {key: 0 for key in gpkg_classes}
        background_val = 0
        pixel_classes[background_val] = 0
        class_prop = validate_class_prop_dict(pixel_classes, class_prop)
        pixel_classes[dontcare] = 0

        # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
        with tqdm(list_data_prep,
                  position=0,
                  leave=False,
                  desc='Preparing samples') as _tqdm:
            for info in _tqdm:
                _tqdm.set_postfix(
                    OrderedDict(tif=f'{Path(info["tif"]).stem}',
                                sample_size=params['global']['samples_size']))
                try:
                    if bucket_name:
                        bucket.download_file(
                            info['tif'], "Images/" + info['tif'].split('/')[-1])
                        info['tif'] = "Images/" + info['tif'].split('/')[-1]
                        if info['gpkg'] not in bucket_file_cache:
                            bucket_file_cache.append(info['gpkg'])
                            bucket.download_file(info['gpkg'],
                                                 info['gpkg'].split('/')[-1])
                        info['gpkg'] = info['gpkg'].split('/')[-1]
                        if info['meta']:
                            if info['meta'] not in bucket_file_cache:
                                bucket_file_cache.append(info['meta'])
                                bucket.download_file(info['meta'],
                                                     info['meta'].split('/')[-1])
                            info['meta'] = info['meta'].split('/')[-1]

                    with rasterio.open(info['tif'], 'r') as raster:
                        # 1. Read the input raster image
                        np_input_image, raster, dataset_nodata = image_reader_as_array(
                            input_image=raster,
                            clip_gpkg=info['gpkg'],
                            aux_vector_file=get_key_def('aux_vector_file',
                                                        params['global'], None),
                            aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                          params['global'], None),
                            aux_vector_ids=get_key_def('aux_vector_ids',
                                                       params['global'], None),
                            aux_vector_dist_maps=get_key_def(
                                'aux_vector_dist_maps', params['global'], True),
                            aux_vector_dist_log=get_key_def(
                                'aux_vector_dist_log', params['global'], True),
                            aux_vector_scale=get_key_def('aux_vector_scale',
                                                         params['global'], None))

                        # 2. Burn vector file in a raster file
                        np_label_raster = vector_to_raster(
                            vector_file=info['gpkg'],
                            input_image=raster,
                            out_shape=np_input_image.shape[:2],
                            attribute_name=info['attribute_name'],
                            fill=background_val
                        )  # background value in rasterized vector.

                        if dataset_nodata is not None:
                            # 3. Set ignore_index value in label array where nodata in raster (only if nodata across all bands)
                            np_label_raster[dataset_nodata] = dontcare

                    if debug:
                        out_meta = raster.meta.copy()
                        np_image_debug = np_input_image.transpose(2, 0, 1).astype(
                            out_meta['dtype'])
                        out_meta.update({
                            "driver": "GTiff",
                            "height": np_image_debug.shape[1],
                            "width": np_image_debug.shape[2]
                        })
                        out_tif = samples_folder / f"np_input_image_{_tqdm.n}.tif"
                        print(f"DEBUG: writing clipped raster to {out_tif}")
                        with rasterio.open(out_tif, "w", **out_meta) as dest:
                            dest.write(np_image_debug)

                        out_meta = raster.meta.copy()
                        np_label_debug = np.expand_dims(
                            np_label_raster,
                            axis=2).transpose(2, 0, 1).astype(out_meta['dtype'])
                        out_meta.update({
                            "driver": "GTiff",
                            "height": np_label_debug.shape[1],
                            "width": np_label_debug.shape[2],
                            'count': 1
                        })
                        out_tif = samples_folder / f"np_label_rasterized_{_tqdm.n}.tif"
                        print(f"DEBUG: writing final rasterized gpkg to {out_tif}")
                        with rasterio.open(out_tif, "w", **out_meta) as dest:
                            dest.write(np_label_debug)

                    # Mask the zeros from input image into label raster.
                    if params['sample']['mask_reference']:
                        np_label_raster = mask_image(np_input_image,
                                                     np_label_raster)

                    if info['dataset'] == 'trn':
                        out_file = trn_hdf5
                    elif info['dataset'] == 'tst':
                        out_file = tst_hdf5
                    else:
                        raise ValueError(
                            f"Dataset value must be trn or tst. Provided value is {info['dataset']}"
                        )
                    val_file = val_hdf5

                    metadata = add_metadata_from_raster_to_sample(
                        sat_img_arr=np_input_image,
                        raster_handle=raster,
                        meta_map=meta_map,
                        raster_info=info)
                    # Save label's per class pixel count to image metadata
                    # (negative values such as -1 are clipped to 0 for bincount).
                    metadata['source_label_bincount'] = {
                        class_num: count
                        for class_num, count in enumerate(
                            np.bincount(np_label_raster.clip(min=0).flatten()))
                        if count > 0
                    }  # TODO: add this to add_metadata_from[...] function?

                    # Add a trailing channel axis: (rows, cols) -> (rows, cols, 1).
                    np_label_raster = np.reshape(
                        np_label_raster,
                        (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                    # 3. Prepare samples!
                    number_samples, number_classes = samples_preparation(
                        in_img_array=np_input_image,
                        label_array=np_label_raster,
                        sample_size=samples_size,
                        overlap=overlap,
                        samples_count=number_samples,
                        num_classes=number_classes,
                        samples_file=out_file,
                        val_percent=val_percent,
                        val_sample_file=val_file,
                        dataset=info['dataset'],
                        pixel_classes=pixel_classes,
                        image_metadata=metadata,
                        dontcare=dontcare,
                        min_annot_perc=min_annot_perc,
                        class_prop=class_prop)

                    _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                    out_file.flush()
                except OSError as e:
                    # Best effort: skip the unreadable pair and keep going.
                    warnings.warn(
                        f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                        f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
                    continue
    finally:
        trn_hdf5.close()
        val_hdf5.close()
        tst_hdf5.close()

    # adds up the number of pixels for each class in pixel_classes dict
    pixel_total = sum(pixel_classes.values())

    # prints the proportion of pixels of each class for the samples created
    for class_id, class_count in pixel_classes.items():
        prop = round((class_count / pixel_total) *
                     100, 1) if pixel_total > 0 else 0
        print('Pixels from class', class_id, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        for hdf5_name in ('trn_samples.hdf5', 'val_samples.hdf5',
                          'tst_samples.hdf5'):
            # Local file path as str; S3 key with forward slashes.
            bucket.upload_file(str(samples_folder / hdf5_name),
                               (final_samples_folder / hdf5_name).as_posix())

    print("End of process")
Beispiel #3
0
def create_csv():
    """
    Inventory label pixels over a regular grid of tiles for every listed image.

    Takes no arguments: configuration is read from the module-level ``params``
    dict. ``params['sample']`` supplies ``prep_csv_file`` (the csv listing the
    images) and ``samples_dist`` (grid step, in pixels). ``params['global']``
    supplies ``samples_size``, ``data_path`` and ``num_classes``.

    Steps:
    - Create the data directory if needed.
    - Delete any previous ``<data_path>/prop_data.csv``.
    - For each csv row:
      - validate the vector file's classes;
      - burn the vector file into a label raster;
      - optionally mask the labels with the input image;
      - pad the labels by half a tile on each side;
      - call ``pixel_inventory`` on every ``samples_size`` x ``samples_size``
        window whose top-left corner lies on the ``samples_dist`` grid.
    """
    sample_cfg = params['sample']
    prep_csv_path = sample_cfg['prep_csv_file']
    grid_step = sample_cfg['samples_dist']
    global_cfg = params['global']
    tile_size = global_cfg['samples_size']
    data_dir = global_cfg['data_path']
    Path(data_dir).mkdir(exist_ok=True)
    class_count = global_cfg['num_classes']
    rows = read_csv(prep_csv_path)

    # Start from a clean proportions file: remove any output of a previous run.
    prop_csv = global_cfg['data_path'] + '/prop_data.csv'
    if os.path.isfile(prop_csv):
        os.remove(prop_csv)

    with tqdm(rows) as progress:
        for row_info in progress:
            progress.set_postfix(OrderedDict(file=f'{row_info["tif"]}', sample_size=global_cfg['samples_size']))

            # The vector file must contain the expected number of classes.
            validate_num_classes(row_info['gpkg'], class_count, row_info['attribute_name'])

            assert os.path.isfile(row_info['tif']), f"could not open raster file at {row_info['tif']}"
            with rasterio.open(row_info['tif'], 'r') as src:
                # Rasterize the vector features on the image grid.
                burn_fill = get_key_def('ignore_idx', get_key_def('training', params, {}), 0)
                label = vector_to_raster(vector_file=row_info['gpkg'],
                                         input_image=src,
                                         attribute_name=row_info['attribute_name'],
                                         fill=burn_fill)

                # Optional auxiliary-vector settings forwarded to the image reader.
                reader_kwargs = dict(
                    aux_vector_file=get_key_def('aux_vector_file', global_cfg, None),
                    aux_vector_attrib=get_key_def('aux_vector_attrib', global_cfg, None),
                    aux_vector_ids=get_key_def('aux_vector_ids', global_cfg, None),
                    aux_vector_dist_maps=get_key_def('aux_vector_dist_maps', global_cfg, True),
                    aux_vector_dist_log=get_key_def('aux_vector_dist_log', global_cfg, True),
                    aux_vector_scale=get_key_def('aux_vector_scale', global_cfg, None),
                )
                image = image_reader_as_array(input_image=src, **reader_kwargs)

                # Copy the input image's zero (no-data) areas into the labels when requested.
                if sample_cfg['mask_reference']:
                    label = images_to_samples.mask_image(image, label)

                # Add a trailing channel axis so the labels are (rows, cols, 1).
                label = np.reshape(label, (label.shape[0], label.shape[1], 1))

                height, width, _ = image.shape

                # Pad by half a tile so windows anchored near the far edges stay full-size.
                pad = int(tile_size / 2)
                padded_label = np.pad(label, ((pad, pad), (pad, pad), (0, 0)), mode='constant')

                grid = ((top, left)
                        for top in range(0, height, grid_step)
                        for left in range(0, width, grid_step))
                for top, left in grid:
                    window = padded_label[top:top + tile_size, left:left + tile_size, :]
                    pixel_inventory(np.squeeze(window, axis=2), tile_size, global_cfg['num_classes'] + 1,
                                    global_cfg['data_path'], row_info['dataset'])
Beispiel #4
0
def main(params: dict) -> None:
    """
    Function to manage details about the inference on segmentation task.
    1. Read the parameters from the config given.
    2. Locate the most recent 'checkpoint.pth.tar' in the state dict directory
       (``inference.state_dict_path``, falling back to ``general.save_weights_dir``)
       and load the model from it.
    3. Make the inference on the data specified in the config
       (``inference.img_dir_or_csv_file``). For segmentation, write one
       ``<image>_inference.tif`` per image. Also write a ``benchmark.gpkg`` when
       benchmarking produced geo data frames.
    -------
    :param params: (dict) Parameters found in the yaml config file.
    :raises FileNotFoundError: if no state dict directory, imagery csv/directory,
        raw data directory or 'checkpoint.pth.tar' can be found.
    :raises ValueError: if the task is neither "classification" nor
        "segmentation", or if benchmarking results are inconsistent.
    """
    # since = time.time()

    # PARAMETERS
    mode = get_key_def('mode', params, expected_type=str)
    task = get_key_def('task_name', params['task'], expected_type=str)
    model_name = get_key_def('model_name', params['model'], expected_type=str).lower()
    num_classes = len(get_key_def('classes_dict', params['dataset']).keys())
    modalities = read_modalities(get_key_def('modalities', params['dataset'], expected_type=str))
    BGR_to_RGB = get_key_def('BGR_to_RGB', params['dataset'], expected_type=bool)
    num_bands = len(modalities)
    debug = get_key_def('debug', params, default=False, expected_type=bool)
    # SETTING OUTPUT DIRECTORY
    try:
        state_dict = Path(params['inference']['state_dict_path']).resolve(strict=True)
    except FileNotFoundError:
        logging.info(
            f"\nThe state dict path directory '{params['inference']['state_dict_path']}' don't seem to be find," +
            f"we will try to locate a state dict path in the '{params['general']['save_weights_dir']}' " +
            f"specify during the training phase"
        )
        try:
            state_dict = Path(params['general']['save_weights_dir']).resolve(strict=True)
        except FileNotFoundError as err:
            # logging.critical() returns None and cannot itself be raised:
            # log the message, then raise a real exception carrying it.
            msg = (
                f"\nThe state dict path directory '{params['general']['save_weights_dir']}'" +
                f" don't seem to be find either, please specify the path to a state dict"
            )
            logging.critical(msg)
            raise FileNotFoundError(msg) from err
    # TODO add more detail in the parent folder
    working_folder = state_dict.parent.joinpath(f'inference_{num_bands}bands')
    logging.info("\nThe state dict path directory used '{}'".format(working_folder))
    Path.mkdir(working_folder, parents=True, exist_ok=True)

    # LOGGING PARAMETERS TODO put option not just mlflow
    experiment_name = get_key_def('project_name', params['general'], default='gdl-training')
    tracking = False  # becomes True once an mlflow run has been started
    try:
        tracker_uri = get_key_def('uri', params['tracker'], default=None, expected_type=str)
        # A tracker section without 'uri' used to crash on Path(None): treat it as "no tracker".
        if tracker_uri is not None:
            Path(tracker_uri).mkdir(exist_ok=True)
            run_name = get_key_def('run_name', params['tracker'], default='gdl')  # TODO change for something meaningful
            run_name = '{}_{}_{}'.format(run_name, mode, task)
            logging.info(f'\nInference and log files will be saved to: {working_folder}')
            # TODO change to fit whatever inport
            from mlflow import log_params, set_tracking_uri, set_experiment, start_run, log_metrics
            # tracking path + parameters logging
            set_tracking_uri(tracker_uri)
            set_experiment(experiment_name)
            start_run(run_name=run_name)
            tracking = True
            for section in ('general', 'dataset', 'data', 'model', 'inference'):
                log_params(dict_path(params, section))
    # meaning no logging tracker as been assigned or it doesnt exist in config/logging
    except ConfigKeyError:
        pass
    if not tracking:
        logging.info(
            "\nNo logging tracker as been assigned or the yaml config doesnt exist in 'config/tracker'."
            "\nNo tracker file will be save in that case."
        )

    # MANDATORY PARAMETERS
    img_dir_or_csv = get_key_def(
        'img_dir_or_csv_file', params['inference'], default=params['general']['raw_data_csv'], expected_type=str
    )
    if not (Path(img_dir_or_csv).is_dir() or Path(img_dir_or_csv).suffix == '.csv'):
        err = FileNotFoundError(
            f'\nCouldn\'t locate .csv file or directory "{img_dir_or_csv}" containing imagery for inference'
        )
        logging.critical(err)
        raise err
    # load the checkpoint
    try:
        # Sort by modification time (mtime) descending so the newest checkpoint is found first
        sorted_by_mtime_descending = sorted(
            [os.path.join(state_dict, x) for x in os.listdir(state_dict)], key=lambda t: -os.stat(t).st_mtime
        )
        last_checkpoint_save = find_first_file('checkpoint.pth.tar', sorted_by_mtime_descending)
        if last_checkpoint_save is None:
            raise FileNotFoundError
        # change the state_dict
        state_dict = last_checkpoint_save
    except FileNotFoundError:
        logging.error(f"\nNo file name 'checkpoint.pth.tar' as been found at '{state_dict}'")
        raise

    # TODO change it next version for all task
    if task not in ['classification', 'segmentation']:
        err = ValueError(f'\nTask should be either "classification" or "segmentation". Got {task}')
        logging.critical(err)
        raise err

    # OPTIONAL PARAMETERS
    dontcare_val = get_key_def("ignore_index", params["training"], default=-1, expected_type=int)
    num_devices = get_key_def('num_gpus', params['training'], default=0, expected_type=int)
    default_max_used_ram = 25
    max_used_ram = get_key_def('max_used_ram', params['training'], default=default_max_used_ram, expected_type=int)
    max_used_perc = get_key_def('max_used_perc', params['training'], default=25, expected_type=int)
    scale = get_key_def('scale_data', params['augmentation'], default=[0, 1], expected_type=ListConfig)
    raster_to_vec = get_key_def('ras2vec', params['inference'], False)  # FIXME not implemented with hydra

    # benchmark (ie when gkpgs are inputted along with imagery)
    dontcare = get_key_def("ignore_index", params["training"], -1)
    targ_ids = None  # TODO get_key_def('target_ids', params['sample'], None, expected_type=List)

    if debug:
        logging.warning(f'\nDebug mode activated. Some debug features may mobilize extra disk space and '
                        f'cause delays in execution.')

    # Assert that all items in target_ids are integers (ex.: to benchmark single-class model with multi-class labels)
    if targ_ids:
        for item in targ_ids:
            if not isinstance(item, int):
                raise ValueError(f'\nTarget id "{item}" in target_ids is {type(item)}, expected int.')

    logging.info(f'\nInferences will be saved to: {working_folder}\n\n')
    if not (0 <= max_used_ram <= 100):
        logging.warning(f'\nMax used ram parameter should be a percentage. Got {max_used_ram}. '
                        f'Will set default value of {default_max_used_ram} %')
        max_used_ram = default_max_used_ram

    # AWS
    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['AWS'])

    # list of GPU devices that are available and unused. If no GPUs, returns empty dict
    gpu_devices_dict = get_device_ids(num_devices,
                                      max_used_ram_perc=max_used_ram,
                                      max_used_perc=max_used_perc)
    if gpu_devices_dict:
        chunk_size = calc_inference_chunk_size(gpu_devices_dict=gpu_devices_dict, max_pix_per_mb_gpu=50)
        logging.info(f"\nNumber of cuda devices requested: {num_devices}. "
                     f"\nCuda devices available: {gpu_devices_dict}. "
                     f"\nUsing {list(gpu_devices_dict.keys())[0]}\n\n")
        # Always the first visible cuda device (index 0).
        device = torch.device('cuda:0')
    else:
        chunk_size = get_key_def('chunk_size', params['inference'], default=512, expected_type=int)
        logging.warning(f"\nNo Cuda device available. This process will only run on CPU")
        device = torch.device('cpu')

    # CONFIGURE MODEL
    num_classes_backgr = add_background_to_num_class(task, num_classes)
    model, loaded_checkpoint, model_name = net(model_name=model_name,
                                               num_bands=num_bands,
                                               num_channels=num_classes_backgr,
                                               dontcare_val=dontcare_val,
                                               num_devices=1,
                                               net_params=params,
                                               inference_state_dict=state_dict)
    try:
        model.to(device)
    except RuntimeError:
        logging.info(f"\nUnable to use device. Trying device 0")
        device = torch.device(f'cuda' if gpu_devices_dict else 'cpu')
        model.to(device)

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    # check if the data folder exist
    raw_data_dir = get_key_def('raw_data_dir', params['dataset'])
    try:
        data_path = Path(raw_data_dir).resolve(strict=True)
    except FileNotFoundError as err:
        msg = "\nImage directory '{}' doesn't exist, please change the path".format(raw_data_dir)
        logging.critical(msg)
        raise FileNotFoundError(msg) from err
    logging.info("\nImage directory used '{}'".format(data_path))
    list_img = list_input_images(
        img_dir_or_csv, bucket_name, glob_patterns=["*.tif", "*.TIF"], in_case_of_path=str(data_path)
    )

    # VALIDATION: anticipate problems with imagery and label (if provided) before entering main for loop
    valid_gpkg_set = set()
    for info in tqdm(list_img, desc='Validating imagery'):
        # validate_raster(info['tif'], num_bands, meta_map)
        if 'gpkg' in info.keys() and info['gpkg'] and info['gpkg'] not in valid_gpkg_set:
            validate_num_classes(vector_file=info['gpkg'],
                                 num_classes=num_classes,
                                 attribute_name=info['attribute_name'],
                                 ignore_index=dontcare,
                                 target_ids=targ_ids)
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    logging.info('\nSuccessfully validated imagery')
    if valid_gpkg_set:
        logging.info('\nSuccessfully validated label data for benchmarking')

    if task == 'classification':
        classifier(params, list_img, model, device,
                   working_folder)  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        gdf_ = []
        gpkg_name_ = []

        # TODO: Add verifications?
        if bucket:
            bucket.download_file(loaded_checkpoint, "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(loaded_checkpoint, model)

        # NOTE: the mlflow run (if any) was already started and its parameters
        # logged above. Starting a second run here made mlflow raise, because
        # a run was still active.

        # LOOP THROUGH LIST OF INPUT IMAGES
        for info in tqdm(list_img, desc='Inferring from images', position=0, leave=True):
            img_name = Path(info['tif']).name
            local_gpkg = Path(info['gpkg']) if 'gpkg' in info.keys() and info['gpkg'] else None
            gpkg_name = local_gpkg.stem if local_gpkg else None
            if bucket:
                local_img = f"Images/{img_name}"
                bucket.download_file(info['tif'], local_img)
                inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                if info['meta']:
                    if info['meta'] not in bucket_file_cache:
                        bucket_file_cache.append(info['meta'])
                        bucket.download_file(info['meta'], info['meta'].split('/')[-1])
                    info['meta'] = info['meta'].split('/')[-1]
            else:  # FIXME: else statement should support img['meta'] integration as well.
                local_img = Path(info['tif'])
                Path.mkdir(working_folder.joinpath(local_img.parent.name), parents=True, exist_ok=True)
                inference_image = working_folder.joinpath(local_img.parent.name,
                                                          f"{img_name.split('.')[0]}_inference.tif")
            # Output sub-folder name, captured before the optional clipping below rebinds
            # local_img (which may not be a Path afterwards).
            img_dir_name = Path(local_img).parent.name
            temp_file = working_folder.joinpath(img_dir_name, f"{img_name.split('.')[0]}.dat")
            raster = rasterio.open(local_img, 'r')
            try:
                logging.info(f'\nReading original image: {raster.name}')
                inf_meta = raster.meta
                label = None
                if local_gpkg:
                    logging.info(f'\nBurning label as raster: {local_gpkg}')
                    local_img = clip_raster_with_gpkg(raster, local_gpkg)
                    raster.close()
                    raster = rasterio.open(local_img, 'r')
                    logging.info(f'\nReading clipped image: {raster.name}')
                    inf_meta = raster.meta
                    label = vector_to_raster(vector_file=local_gpkg,
                                             input_image=raster,
                                             out_shape=(inf_meta['height'], inf_meta['width']),
                                             attribute_name=info['attribute_name'],
                                             fill=0,  # background value in rasterized vector.
                                             target_ids=targ_ids)
                    if debug:
                        logging.debug(f'\nUnique values in loaded label as raster: {np.unique(label)}\n'
                                      f'Shape of label as raster: {label.shape}')
                pred, gdf = segmentation(param=params,
                                         input_image=raster,
                                         label_arr=label,
                                         num_classes=num_classes_backgr,
                                         gpkg_name=gpkg_name,
                                         model=model,
                                         chunk_size=chunk_size,
                                         device=device,
                                         scale=scale,
                                         BGR_to_RGB=BGR_to_RGB,
                                         tp_mem=temp_file,
                                         debug=debug)
            finally:
                # Release the dataset's file handle for every image, even on failure.
                raster.close()
            if gdf is not None:
                gdf_.append(gdf)
                gpkg_name_.append(gpkg_name)
            if local_gpkg and tracking:
                pixelMetrics = ComputePixelMetrics(label, pred, num_classes_backgr)
                log_metrics(pixelMetrics.update(pixelMetrics.iou))
                log_metrics(pixelMetrics.update(pixelMetrics.dice))
            pred = pred[np.newaxis, :, :].astype(np.uint8)
            inf_meta.update({"driver": "GTiff",
                             "height": pred.shape[1],
                             "width": pred.shape[2],
                             "count": pred.shape[0],
                             "dtype": 'uint8',
                             "compress": 'lzw'})
            logging.info(f'\nSuccessfully inferred on {img_name}\nWriting to file: {inference_image}')
            with rasterio.open(inference_image, 'w+', **inf_meta) as dest:
                dest.write(pred)
            del pred
            try:
                temp_file.unlink()
            except OSError as e:
                logging.warning(f'File Error: {temp_file, e.strerror}')
            if raster_to_vec:
                start_vec = time.time()
                inference_vec = working_folder.joinpath(img_dir_name,
                                                        f"{img_name.split('.')[0]}_inference.gpkg")
                ras2vec(inference_image, inference_vec)
                end_vec = time.time() - start_vec
                logging.info('Vectorization completed in {:.0f}m {:.0f}s'.format(end_vec // 60, end_vec % 60))

        if len(gdf_) >= 1:
            if not len(gdf_) == len(gpkg_name_):
                err = ValueError('\nbenchmarking unable to complete')
                logging.critical(err)
                raise err
            all_gdf = pd.concat(gdf_)  # Concatenate all geo data frame into one geo data frame
            all_gdf.reset_index(drop=True, inplace=True)
            gdf_x = gpd.GeoDataFrame(all_gdf)
            bench_gpkg = working_folder / "benchmark.gpkg"
            gdf_x.to_file(bench_gpkg, driver="GPKG", index=False)
            logging.info(f'\nSuccessfully wrote benchmark geopackage to: {bench_gpkg}')
        # log_artifact(working_folder)
def main(params):
    """
    Training and validation datasets preparation.

    Reads the sample-preparation csv (``sample.prep_csv_file``). For each row, it:
    - burns the vector file (gpkg) into a label raster;
    - reads the matching image (tif);
    - writes samples into the trn/val/tst hdf5 files of a new samples folder.

    When ``global.bucket_name`` is set, the csv and the inputs are downloaded
    from that S3 bucket and the three hdf5 files are uploaded back to it.

    :param params: (dict) Parameters found in the yaml config file.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []

    assert params['global'][
        'task'] == 'segmentation', f"images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = params['global']['bucket_name']
    data_path = Path(params['global']['data_path'])
    Path.mkdir(data_path, exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = params['sample']['sampling']['map']
    num_bands = params['global']['number_of_bands']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(f'Debug mode activate. Execution may take longer...')

    # TODO: validate this is preferred name structure
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'
    final_samples_folder = None
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        if data_path:
            final_samples_folder = os.path.join(data_path, "samples")
        else:
            final_samples_folder = "samples"
        # Keep a Path (not a str) so .is_dir() and the '/' joins below work in bucket mode too.
        samples_folder = Path(samples_folder_name)

    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(samples_folder_name)

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    Path.mkdir(samples_folder, exist_ok=False)  # FIXME: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(
        f'\nSuccessfully read csv file: {Path(csv_file).stem}\nNumber of rows: {len(list_data_prep)}\nCopying first entry:\n{list_data_prep[0]}\n'
    )
    ignore_index = get_key_def('ignore_index', params['training'], -1)

    # NOTE(review): in bucket mode info['tif'] / info['gpkg'] are bucket keys that are only
    # downloaded later in the loop, so this local-file check presumably fails there — confirm.
    for info in tqdm(list_data_prep,
                     position=0,
                     desc=f'Asserting existence of tif and gpkg files in csv'):
        assert Path(info['tif']).is_file(
        ), f'Could not locate "{info["tif"]}". Make sure file exists in this directory.'
        assert Path(info['gpkg']).is_file(
        ), f'Could not locate "{info["gpkg"]}". Make sure file exists in this directory.'
    if debug:
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Validating presence of {params['global']['num_classes']} "
                f"classes in attribute \"{info['attribute_name']}\" for vector "
                f"file \"{Path(info['gpkg']).stem}\""):
            validate_num_classes(info['gpkg'], params['global']['num_classes'],
                                 info['attribute_name'], ignore_index)
        with tqdm(list_data_prep,
                  position=0,
                  desc=f"Checking validity of features in vector files"
                  ) as _tqdm:
            invalid_features = {}
            for info in _tqdm:
                # Extract vector features to burn in the raster image
                with fiona.open(
                        info['gpkg'],
                        'r') as src:  # TODO: refactor as independent function
                    lst_vector = [vector for vector in src]
                shapes = lst_ids(list_vector=lst_vector,
                                 attr_name=info['attribute_name'])
                for index, item in enumerate(
                        tqdm([v for vecs in shapes.values() for v in vecs],
                             leave=False,
                             position=1)):
                    # geom must be a valid GeoJSON geometry type and non-empty
                    geom, value = item
                    geom = getattr(geom, '__geo_interface__', None) or geom
                    if not is_valid_geom(geom):
                        gpkg_stem = str(Path(info['gpkg']).stem)
                        if gpkg_stem not in invalid_features.keys(
                        ):  # create key with name of gpkg
                            invalid_features[gpkg_stem] = []
                        if lst_vector[index]["id"] not in invalid_features[
                                gpkg_stem]:  # ignore feature is already appended
                            invalid_features[gpkg_stem].append(
                                lst_vector[index]["id"])
            assert len(
                invalid_features.values()
            ) == 0, f'Invalid geometry object(s) for "gpkg:ids": \"{invalid_features}\"'

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    # 'sampling' ordereddict validation
    check_sampling_dict()

    pixel_classes = {}
    # creates pixel_classes dict and keys
    for i in range(0, params['global']['num_classes'] + 1):
        pixel_classes.update({i: 0})
    pixel_classes.update(
        {ignore_index: 0}
    )  # FIXME: pixel_classes dict needs to be populated with classes obtained from target

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    with tqdm(list_data_prep,
              position=0,
              leave=False,
              desc=f'Preparing samples') as _tqdm:
        for info in _tqdm:
            _tqdm.set_postfix(
                OrderedDict(tif=f'{Path(info["tif"]).stem}',
                            sample_size=params['global']['samples_size']))
            try:
                if bucket_name:
                    bucket.download_file(
                        info['tif'], "Images/" + info['tif'].split('/')[-1])
                    info['tif'] = "Images/" + info['tif'].split('/')[-1]
                    if info['gpkg'] not in bucket_file_cache:
                        bucket_file_cache.append(info['gpkg'])
                        bucket.download_file(info['gpkg'],
                                             info['gpkg'].split('/')[-1])
                    info['gpkg'] = info['gpkg'].split('/')[-1]
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'],
                                                 info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]

                with rasterio.open(info['tif'], 'r') as raster:
                    # Burn vector file in a raster file
                    np_label_raster = vector_to_raster(
                        vector_file=info['gpkg'],
                        input_image=raster,
                        attribute_name=info['attribute_name'],
                        fill=get_key_def('ignore_idx',
                                         get_key_def('training', params, {}),
                                         0))
                    # Read the input raster image
                    np_input_image = image_reader_as_array(
                        input_image=raster,
                        scale=get_key_def('scale_data', params['global'],
                                          None),
                        aux_vector_file=get_key_def('aux_vector_file',
                                                    params['global'], None),
                        aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                      params['global'], None),
                        aux_vector_ids=get_key_def('aux_vector_ids',
                                                   params['global'], None),
                        aux_vector_dist_maps=get_key_def(
                            'aux_vector_dist_maps', params['global'], True),
                        aux_vector_dist_log=get_key_def(
                            'aux_vector_dist_log', params['global'], True),
                        aux_vector_scale=get_key_def('aux_vector_scale',
                                                     params['global'], None))

                # Mask the zeros from input image into label raster.
                if params['sample']['mask_reference']:
                    np_label_raster = mask_image(np_input_image,
                                                 np_label_raster)

                # val_hdf5 is the only value val_file ever takes. Bind it for every row:
                # a 'tst' row processed before any 'trn' row otherwise hit an unbound name.
                val_file = val_hdf5
                if info['dataset'] == 'trn':
                    out_file = trn_hdf5
                elif info['dataset'] == 'tst':
                    out_file = tst_hdf5
                else:
                    raise ValueError(
                        f"Dataset value must be trn or val or tst. Provided value is {info['dataset']}"
                    )

                meta_map, metadata = get_key_def("meta_map", params["global"],
                                                 {}), None
                if info['meta'] is not None and isinstance(
                        info['meta'], str) and Path(info['meta']).is_file():
                    metadata = read_parameters(info['meta'])

                # FIXME: think this through. User will have to calculate the total number of bands including meta layers and
                #  specify it in yaml. Is this the best approach? What if metalayers are added on the fly ?
                input_band_count = np_input_image.shape[
                    2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                # FIXME: could this assert be done before getting into this big for loop?
                assert input_band_count == num_bands, \
                    f"The number of bands in the input image ({input_band_count}) and the parameter" \
                    f"'number_of_bands' in the yaml file ({params['global']['number_of_bands']}) should be identical"

                np_label_raster = np.reshape(
                    np_label_raster,
                    (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                number_samples, number_classes = samples_preparation(
                    np_input_image, np_label_raster, samples_size, overlap,
                    number_samples, number_classes, out_file, val_percent,
                    val_file, info['dataset'], pixel_classes, metadata)

                _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                out_file.flush()
            except Exception as e:
                # Best effort: a failing row is reported and skipped, the others are still processed.
                warnings.warn(
                    f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                    f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
                continue

    trn_hdf5.close()
    val_hdf5.close()
    tst_hdf5.close()

    # adds up the number of pixels for each class in pixel_classes dict
    pixel_total = sum(pixel_classes.values())

    # prints the proportion of pixels of each class for the samples created
    # (0 when no pixel was counted, instead of a ZeroDivisionError)
    for i in pixel_classes:
        prop = round((pixel_classes[i] / pixel_total) * 100, 1) if pixel_total > 0 else 0
        print('Pixels from class', i, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        for hdf5_name in ('trn_samples.hdf5', 'val_samples.hdf5', 'tst_samples.hdf5'):
            # samples_folder is a Path: join with '/' rather than '+' (Path + str raises TypeError).
            bucket.upload_file(str(samples_folder / hdf5_name),
                               final_samples_folder + '/' + hdf5_name)

    print("End of process")
def _select_gpkg(gpkg_entries, root_folder):
    """
    Return the GeoPackage to use for one image record of the JSON input file, or None.

    :param gpkg_entries: (dict) the record's 'gpkg' mapping; a falsy value means "no GeoPackage".
    :param root_folder: (Path) folder against which the relative paths of the JSON file are resolved.
    :return: (Path or None) None when there is no GeoPackage or when the mapping has neither a
             'corr' nor a 'prem' key; otherwise the first listed GeoPackage joined to root_folder.
    """
    if not gpkg_entries:
        return None
    # Written out explicitly: `'corr' or 'prem' in keys` parses as `'corr' or (...)` and is always true.
    if 'corr' not in gpkg_entries and 'prem' not in gpkg_entries:
        return None
    # NOTE(review): the first listed entry is used, which is not necessarily the 'corr'/'prem' one
    # when other keys come first -- confirm this is intended.
    return root_folder.joinpath(list(gpkg_entries.values())[0])


def main(params):
    """
    Dataset preparation (trn, val, tst) from a JSON inventory of images.

    Every record of ``read_img.input_file`` whose ``sensorID`` is in ``read_img.sensorID`` and
    whose date month is in ``read_img.month_range`` is sampled for each requested source:

    * ``read_img.source.pan``: every image listed in the record's ``pan_img``;
    * ``read_img.source.mul``: every image listed in ``mul_img``, with the band order returned by
      ``reorder_bands(mul_band, read_img.source.mulband)``;
    * ``read_img.prep.band``: the single-band rasters ``<band>_band`` (R, G, B and/or N) of the
      record, combined into one sample (stacked with ``np.dstack`` when there are several bands).

    Samples are produced by gen_img_samples / gen_label_samples and handed to sample_prep, which
    receives the tst hdf5 file when the GeoPackage stem is listed in ``read_img.benchmark`` and the
    trn hdf5 file otherwise (plus the val file and ``sample.val_percent``). The hdf5 files live in
    ``<global.data_path>/samples<size>_overlap<overlap>_min-annot<perc>_<bands>bands_<experiment>``.
    Per-source class pixel counts are passed to class_pixel_ratio with ``class_distribution.txt``
    from that folder.

    :param params: (dict) Parameters found in the yaml config file.
    :raises AssertionError: if ``global.task`` is not 'segmentation'.
    :raises FileNotFoundError: if ``global.data_path`` is not a directory after creating it.
    :raises FileExistsError: if the samples folder already exists and debug mode is off.
    """
    assert params['global']['task'] == 'segmentation', \
        f"sample_creation.py isn't necessary when performing classification tasks"
    num_classes = get_key_def('num_classes', params['global'], expected_type=int)
    num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
    debug = get_key_def('debug_mode', params['global'], False)
    targ_ids = get_key_def('target_ids', params['sample'], None, expected_type=List)

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    # Distance between consecutive samples: samples_size reduced by the overlap percentage.
    dist_samples = round(samples_size * (1 - (overlap / 100)))
    sampling_method = params['sample']['sampling_method']
    min_annot_perc = get_key_def('min_annotated_percent', sampling_method, None, expected_type=int)
    class_prop = get_key_def('class_proportion', sampling_method, None, expected_type=dict)
    dontcare = get_key_def("ignore_index", params["training"], -1)
    meta_map = get_key_def('meta_map', params['global'], default={})

    list_params = params['read_img']
    source_pan = get_key_def('pan', list_params['source'], default=False, expected_type=bool)
    source_mul = get_key_def('mul', list_params['source'], default=False, expected_type=bool)
    mul_band_order = get_key_def('mulband', list_params['source'], default=[], expected_type=list)
    prep_band = get_key_def('band', list_params['prep'], default=[], expected_type=list)
    tst_set = get_key_def('benchmark', list_params, default=[], expected_type=list)
    in_pth = get_key_def('input_file', list_params, default='data_file.json', expected_type=str)
    sensor_lst = get_key_def('sensorID',
                             list_params,
                             default=['GeoEye1', 'QuickBird2', 'WV2', 'WV3', 'WV4'],
                             expected_type=list)
    month_range = get_key_def('month_range',
                              list_params,
                              default=list(range(1, 12 + 1)),
                              expected_type=list)
    root_folder = Path(
        get_key_def('root_img_folder', list_params, default='', expected_type=str))

    data_path = Path(params['global']['data_path'])
    data_path.mkdir(exist_ok=True, parents=True)
    if not data_path.is_dir():
        raise FileNotFoundError(f'Could not locate data path {data_path}')

    # mlflow logging
    experiment_name = get_key_def('mlflow_experiment_name',
                                  params['global'],
                                  default='gdl-training',
                                  expected_type=str)
    samples_folder_name = (
        f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'
        f'_{experiment_name}')
    samples_folder = data_path.joinpath(samples_folder_name)
    if samples_folder.is_dir():
        if debug:
            # Keep the previous run: rename its folder, suffixed with its last-modification time,
            # next to where it was (joining str(samples_folder) onto data_path duplicated a
            # relative data_path).
            last_mod_time_suffix = datetime.fromtimestamp(
                samples_folder.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
            shutil.move(samples_folder,
                        samples_folder.with_name(f'{samples_folder.name}_{last_mod_time_suffix}'))
        else:
            raise FileExistsError(
                f'Data path exists: {samples_folder}. Remove it or use a different experiment_name.'
            )

    samples_folder.mkdir(exist_ok=False)  # TODO: what if we want to append samples to existing hdf5?
    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        samples_size=samples_size,
        number_of_bands=num_bands,
        meta_map=meta_map,
        samples_folder=samples_folder,
        params=params)

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0
    pixel_pan_counter = Counter()
    pixel_mul_counter = Counter()
    pixel_prep_counter = Counter()
    filename = samples_folder.joinpath('class_distribution.txt')

    def classes_of(gpkg):
        """Return what validate_num_classes reports for ``gpkg`` with this run's settings."""
        return validate_num_classes(gpkg, num_classes, 'properties/Quatreclasses', dontcare,
                                    targ_ids)

    def write_samples(src, gpkg, gpkg_classes, img_gen, label_gen, pixel_counter):
        """Feed every (image, label) sample pair of one raster to sample_prep.

        The tst file is used when the GeoPackage stem is in the benchmark list, the trn file
        otherwise. Updates the running sample/class counts and ``pixel_counter``.
        """
        nonlocal number_samples, number_classes
        if Path(gpkg).stem in tst_set:
            sample_type, out_file = 'tst', tst_hdf5
        else:
            sample_type, out_file = 'trn', trn_hdf5
        for img, label in zip(img_gen, label_gen):
            number_samples, number_classes, class_pixels = sample_prep(
                src,
                img,
                label[0],
                label[1],
                gpkg_classes,
                samples_size,
                sample_type,
                number_samples,
                out_file,
                number_classes,
                val_percent,
                val_hdf5,
                min_annot_perc,
                class_prop=class_prop,
                dontcare=dontcare)
            pixel_counter.update(class_pixels)

    def sample_images(img_paths, gpkg, gpkg_classes, pixel_counter, *img_gen_args):
        """Sample each image of one source (pan or mul) against the same GeoPackage.

        Images for which process_vector_label returns None are skipped. ``img_gen_args`` are
        extra positional arguments forwarded to gen_img_samples (the band order for mul).
        """
        for img_rel_path in img_paths:
            img_path = root_folder.joinpath(img_rel_path)
            assert_crs_match(img_path, gpkg)
            rst_pth, src = process_raster_img(img_path, gpkg)
            np_label = process_vector_label(rst_pth, gpkg, targ_ids)
            if np_label is None:
                continue
            label_gen = gen_label_samples(np_label, dist_samples, samples_size)
            img_gen = gen_img_samples(rst_pth, samples_size, dist_samples, *img_gen_args)
            # Consume this image's samples before moving on: previously the generators were
            # overwritten by the next image, so only the last image of a record was sampled.
            write_samples(src, gpkg, gpkg_classes, img_gen, label_gen, pixel_counter)

    with open(Path(in_pth), 'r') as fin:
        dict_images = json.load(fin)

    try:
        for i_dict in tqdm(dict_images['all_images'],
                           desc=f'Writing samples to {samples_folder}'):
            # The date is only parsed for records whose sensor is selected.
            if i_dict['sensorID'] not in sensor_lst:
                continue
            if datetime.strptime(i_dict['date']['yyyy/mm/dd'],
                                 '%Y/%m/%d').month not in month_range:
                continue

            if source_pan and i_dict['pan_img']:
                gpkg = _select_gpkg(i_dict['gpkg'], root_folder)
                if gpkg is not None:
                    sample_images(i_dict['pan_img'], gpkg, classes_of(gpkg), pixel_pan_counter)

            if source_mul and i_dict['mul_img']:
                gpkg = _select_gpkg(i_dict['gpkg'], root_folder)
                if gpkg is not None:
                    band_order = reorder_bands(i_dict['mul_band'], mul_band_order)
                    sample_images(i_dict['mul_img'], gpkg, classes_of(gpkg), pixel_mul_counter,
                                  band_order)

            if prep_band and set(prep_band).issubset({'R', 'G', 'B', 'N'}):
                band_paths = [root_folder.joinpath(i_dict[f'{band}_band'])
                              for band in prep_band if i_dict[f'{band}_band']]
                gpkg = _select_gpkg(i_dict['gpkg'], root_folder)
                if not band_paths or gpkg is None:
                    continue
                gpkg_classes = classes_of(gpkg)
                band_gens = []
                np_label = src = None  # reset per record so nothing leaks from a previous one
                for band_path in band_paths:
                    assert_crs_match(band_path, gpkg)
                    rst_pth, src = process_raster_img(band_path, gpkg)
                    # The label (and src) of the last band is used for the stacked sample.
                    # NOTE(review): assumes every band raster covers the same extent.
                    np_label = process_vector_label(rst_pth, gpkg, targ_ids)
                    band_gens.append(gen_img_samples(rst_pth, samples_size, dist_samples))
                if np_label is None:
                    continue
                label_gen = gen_label_samples(np_label, dist_samples, samples_size)
                # One sample per position: a single band is used as is, several are stacked.
                img_gen = (bands[0] if len(bands) == 1 else np.dstack(np.array(bands))
                           for bands in zip(*band_gens))
                write_samples(src, gpkg, gpkg_classes, img_gen, label_gen, pixel_prep_counter)
    finally:
        # Close the hdf5 files even if sampling is interrupted by an error.
        trn_hdf5.close()
        val_hdf5.close()
        tst_hdf5.close()

    class_pixel_ratio(pixel_pan_counter, 'pan_source', filename)
    class_pixel_ratio(pixel_mul_counter, 'mul_source', filename)
    class_pixel_ratio(pixel_prep_counter, 'prep_source', filename)
    print("Number of samples created: ", number_samples, number_classes)
Beispiel #7
0
def main(cfg: DictConfig) -> None:
    """
    Function that create training, validation and testing datasets preparation.

    1. Read csv file and validate existence of all input files and GeoPackages.
    2. Do the following verifications:
        1. Assert number of bands found in raster is equal to desired number
           of bands.
        2. Check that `num_classes` is equal to number of classes detected in
           the specified attribute for each GeoPackage.
           Warning: this validation will not succeed if a Geopackage
                    contains only a subset of `num_classes` (e.g. 3 of 4).
        3. Assert Coordinate reference system between raster and gpkg match.
    3. Read csv file and for each line in the file, do the following:
        1. Read input image as array with utils.readers.image_reader_as_array().
            - If gpkg's extent is smaller than raster's extent,
              raster is clipped to gpkg's extent.
            - If gpkg's extent is bigger than raster's extent,
              gpkg is clipped to raster's extent.
        2. Convert GeoPackage vector information into the "label" raster with
           utils.utils.vector_to_raster(). The pixel value is determined by the
           attribute in the csv file.
        3. Create a new raster called "label" with the same properties as the
           input image.
        4. Read metadata and add to input as new bands (*more details to come*).
        5. Crop the arrays in smaller samples of the size `samples_size` of
           `your_conf.yaml`. Visual representation of this is provided at
            https://medium.com/the-downlinq/broad-area-satellite-imagery-semantic-segmentation-basiss-4a7ea2c8466f
        6. Write samples from input image and label into the "val", "trn" or
           "tst" hdf5 file, depending on the value contained in the csv file.
            Refer to samples_preparation().

    -------
    :param cfg: (dict) Parameters found in the yaml config file.
    """
    try:
        import boto3
    except ModuleNotFoundError:
        logging.warning(
            "\nThe boto3 library couldn't be imported. Ignore if not using AWS s3 buckets",
            ImportWarning)
        pass

    # PARAMETERS
    num_classes = len(cfg.dataset.classes_dict.keys())
    num_bands = len(cfg.dataset.modalities)
    modalities = read_modalities(
        cfg.dataset.modalities
    )  # TODO add the Victor module to manage the modalities
    debug = cfg.debug
    task = cfg.task.task_name

    # RAW DATA PARAMETERS
    # Data folder
    try:
        # check if the folder exist
        my_data_path = Path(cfg.dataset.raw_data_dir).resolve(strict=True)
        logging.info("\nImage directory used '{}'".format(my_data_path))
        data_path = Path(my_data_path)
    except FileNotFoundError:
        raise logging.critical(
            "\nImage directory '{}' doesn't exist, please change the path".
            format(cfg.dataset.raw_data_dir))
    # CSV file
    try:
        my_csv_path = Path(cfg.dataset.raw_data_csv).resolve(strict=True)
        # path.exists(cfg.dataset.raw_data_csv)
        logging.info("\nImage csv: '{}'".format(my_csv_path))
        csv_file = my_csv_path
    except FileNotFoundError:
        raise logging.critical(
            "\nImage csv '{}' doesn't exist, please change the path".format(
                cfg.dataset.raw_data_csv))
    # HDF5 data
    try:
        my_hdf5_path = Path(str(
            cfg.dataset.sample_data_dir)).resolve(strict=True)
        logging.info("\nThe HDF5 directory used '{}'".format(my_hdf5_path))
        Path.mkdir(Path(my_hdf5_path), exist_ok=True, parents=True)
    except FileNotFoundError:
        logging.info(
            "\nThe HDF5 directory '{}' doesn't exist, please change the path.".
            format(cfg.dataset.sample_data_dir) +
            "\nFor now the HDF5 directory use will be change for '{}'".format(
                data_path))
        cfg.general.sample_data_dir = str(data_path)

    # SAMPLE PARAMETERS
    samples_size = get_key_def('input_dim',
                               cfg['dataset'],
                               default=256,
                               expected_type=int)
    overlap = get_key_def('overlap', cfg['dataset'], default=0)
    min_annot_perc = get_key_def('min_annot_perc', cfg['dataset'], default=0)
    val_percent = get_key_def('train_val_percent', cfg['dataset'],
                              default=0.3)['val'] * 100
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}' \
                          f'_{num_bands}bands_{cfg.general.project_name}'
    samples_dir = data_path.joinpath(samples_folder_name)
    if samples_dir.is_dir():
        if debug:
            # Move existing data folder with a random suffix.
            last_mod_time_suffix = datetime.fromtimestamp(
                samples_dir.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
            shutil.move(
                samples_dir,
                data_path.joinpath(
                    f'{str(samples_dir)}_{last_mod_time_suffix}'))
        else:
            logging.critical(
                f'Data path exists: {samples_dir}. Remove it or use a different experiment_name.'
            )
            raise FileExistsError()
    Path.mkdir(samples_dir, exist_ok=False
               )  # TODO: what if we want to append samples to existing hdf5?

    # LOGGING PARAMETERS  TODO see logging yaml
    experiment_name = cfg.general.project_name
    # mlflow_uri = get_key_def('mlflow_uri', params['global'], default="./mlruns")

    # OTHER PARAMETERS
    metadata = None
    meta_map = {}  # TODO get_key_def('meta_map', params['global'], default={})
    # TODO class_prop get_key_def('class_proportion', params['sample']['sampling_method'], None, expected_type=dict)
    class_prop = None
    mask_reference = False  # TODO get_key_def('mask_reference', params['sample'], default=False, expected_type=bool)
    # set dontcare (aka ignore_index) value
    dontcare = cfg.dataset.ignore_index if cfg.dataset.ignore_index is not None else -1
    if dontcare == 0:
        logging.warning(
            "\nThe 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero."
            " All valid class indices should be consecutive, and start at 0. The 'dontcare' value"
            " will be remapped to -1 while loading the dataset, and inside the config from now on."
        )
        dontcare = -1
    # Assert that all items in target_ids are integers (ex.: single-class samples from multi-class label)
    targ_ids = None  # TODO get_key_def('target_ids', params['sample'], None, expected_type=List)
    if targ_ids is list:
        for item in targ_ids:
            if not isinstance(item, int):
                raise logging.critical(
                    ValueError(
                        f'\nTarget id "{item}" in target_ids is {type(item)}, expected int.'
                    ))

    # OPTIONAL
    use_stratification = cfg.dataset.use_stratification if cfg.dataset.use_stratification is not None else False
    if use_stratification:
        stratd = {
            'trn': {
                'total_pixels': 0,
                'total_counts': {},
                'total_props': {}
            },
            'val': {
                'total_pixels': 0,
                'total_counts': {},
                'total_props': {}
            },
            'strat_factor': cfg['dataset']['use_stratification']
        }
    else:
        stratd = None

    # ADD GIT HASH FROM CURRENT COMMIT TO PARAMETERS (if available and parameters will be saved to hdf5s).
    with open_dict(cfg):
        cfg.general.git_hash = get_git_hash()

    # AWS TODO
    bucket_name = cfg.AWS.bucket_name
    if bucket_name:
        final_samples_folder = None
        bucket_name = cfg.AWS.bucket_name
        bucket_file_cache = []
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv', data_path)
    else:
        list_data_prep = read_csv(csv_file, data_path)

    # IF DEBUG IS ACTIVATE
    if debug:
        logging.warning(
            f'\nDebug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.'
        )

    # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg and (2) check CRS match (tif and gpkg)
    valid_gpkg_set = set()
    for info in tqdm(list_data_prep, position=0):
        validate_raster(info['tif'], num_bands, meta_map)
        if info['gpkg'] not in valid_gpkg_set:
            gpkg_classes = validate_num_classes(
                info['gpkg'],
                num_classes,
                info['attribute_name'],
                dontcare,
                target_ids=targ_ids,
            )
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    if debug:
        # VALIDATION (debug only): Checking validity of features in vector files
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Checking validity of features in vector files"):
            # TODO: make unit to test this with invalid features.
            invalid_features = validate_features_from_gpkg(
                info['gpkg'], info['attribute_name'])
            if invalid_features:
                logging.critical(
                    f"{info['gpkg']}: Invalid geometry object(s) '{invalid_features}'"
                )

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    # with open_dict(cfg):
    #     print(cfg)
    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        samples_size=samples_size,
        number_of_bands=num_bands,
        meta_map=meta_map,
        samples_folder=samples_dir,
        cfg=cfg)

    # creates pixel_classes dict and keys
    pixel_classes = {key: 0 for key in gpkg_classes}
    background_val = 0
    pixel_classes[background_val] = 0
    class_prop = validate_class_prop_dict(pixel_classes, class_prop)
    pixel_classes[dontcare] = 0

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    logging.info(
        f"\nPreparing samples \n  Samples_size: {samples_size} \n  Overlap: {overlap} "
        f"\n  Validation set: {val_percent} % of created training samples")
    for info in tqdm(list_data_prep, position=0, leave=False):
        try:
            if bucket_name:
                bucket.download_file(info['tif'],
                                     "Images/" + info['tif'].split('/')[-1])
                info['tif'] = "Images/" + info['tif'].split('/')[-1]
                if info['gpkg'] not in bucket_file_cache:
                    bucket_file_cache.append(info['gpkg'])
                    bucket.download_file(info['gpkg'],
                                         info['gpkg'].split('/')[-1])
                info['gpkg'] = info['gpkg'].split('/')[-1]
                if info['meta']:
                    if info['meta'] not in bucket_file_cache:
                        bucket_file_cache.append(info['meta'])
                        bucket.download_file(info['meta'],
                                             info['meta'].split('/')[-1])
                    info['meta'] = info['meta'].split('/')[-1]

            logging.info(f"\nReading as array: {info['tif']}")
            with rasterio.open(info['tif'], 'r') as raster:
                # 1. Read the input raster image
                np_input_image, raster, dataset_nodata = image_reader_as_array(
                    input_image=raster,
                    clip_gpkg=info['gpkg'],
                    aux_vector_file=get_key_def('aux_vector_file',
                                                cfg['dataset'], None),
                    aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                  cfg['dataset'], None),
                    aux_vector_ids=get_key_def('aux_vector_ids',
                                               cfg['dataset'], None),
                    aux_vector_dist_maps=get_key_def('aux_vector_dist_maps',
                                                     cfg['dataset'], True),
                    aux_vector_dist_log=get_key_def('aux_vector_dist_log',
                                                    cfg['dataset'], True),
                    aux_vector_scale=get_key_def('aux_vector_scale',
                                                 cfg['dataset'], None))

                # 2. Burn vector file in a raster file
                logging.info(
                    f"\nRasterizing vector file (attribute: {info['attribute_name']}): {info['gpkg']}"
                )
                np_label_raster = vector_to_raster(
                    vector_file=info['gpkg'],
                    input_image=raster,
                    out_shape=np_input_image.shape[:2],
                    attribute_name=info['attribute_name'],
                    fill=background_val,
                    target_ids=targ_ids
                )  # background value in rasterized vector.

                if dataset_nodata is not None:
                    # 3. Set ignore_index value in label array where nodata in raster (only if nodata across all bands)
                    np_label_raster[dataset_nodata] = dontcare

            if debug:
                out_meta = raster.meta.copy()
                np_image_debug = np_input_image.transpose(2, 0, 1).astype(
                    out_meta['dtype'])
                out_meta.update({
                    "driver": "GTiff",
                    "height": np_image_debug.shape[1],
                    "width": np_image_debug.shape[2]
                })
                out_tif = samples_dir / f"{Path(info['tif']).stem}_clipped.tif"
                logging.debug(f"Writing clipped raster to {out_tif}")
                with rasterio.open(out_tif, "w", **out_meta) as dest:
                    dest.write(np_image_debug)

                out_meta = raster.meta.copy()
                np_label_debug = np.expand_dims(
                    np_label_raster,
                    axis=2).transpose(2, 0, 1).astype(out_meta['dtype'])
                out_meta.update({
                    "driver": "GTiff",
                    "height": np_label_debug.shape[1],
                    "width": np_label_debug.shape[2],
                    'count': 1
                })
                out_tif = samples_dir / f"{Path(info['gpkg']).stem}_clipped.tif"
                logging.debug(f"\nWriting final rasterized gpkg to {out_tif}")
                with rasterio.open(out_tif, "w", **out_meta) as dest:
                    dest.write(np_label_debug)

            # Mask the zeros from input image into label raster.
            if mask_reference:
                np_label_raster = mask_image(np_input_image, np_label_raster)

            if info['dataset'] == 'trn':
                out_file = trn_hdf5
            elif info['dataset'] == 'tst':
                out_file = tst_hdf5
            else:
                raise logging.critical(
                    ValueError(
                        f"\nDataset value must be trn or tst. Provided value is {info['dataset']}"
                    ))
            val_file = val_hdf5

            metadata = add_metadata_from_raster_to_sample(
                sat_img_arr=np_input_image,
                raster_handle=raster,
                meta_map=meta_map,
                raster_info=info)
            # Save label's per class pixel count to image metadata
            metadata['source_label_bincount'] = {
                class_num: count
                for class_num, count in enumerate(
                    np.bincount(np_label_raster.clip(min=0).flatten()))
                if count > 0
            }  # TODO: add this to add_metadata_from[...] function?

            np_label_raster = np.reshape(
                np_label_raster,
                (np_label_raster.shape[0], np_label_raster.shape[1], 1))
            # 3. Prepare samples!
            number_samples, number_classes = samples_preparation(
                in_img_array=np_input_image,
                label_array=np_label_raster,
                sample_size=samples_size,
                overlap=overlap,
                samples_count=number_samples,
                num_classes=number_classes,
                samples_file=out_file,
                val_percent=val_percent,
                val_sample_file=val_file,
                dataset=info['dataset'],
                pixel_classes=pixel_classes,
                dontcare=dontcare,
                image_metadata=metadata,
                min_annot_perc=min_annot_perc,
                class_prop=class_prop,
                stratd=stratd)

            # logging.info(f'\nNumber of samples={number_samples}')
            out_file.flush()
        except OSError:
            logging.exception(
                f'\nAn error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                f'{Path(info["gpkg"]).stem} (gpkg).')
            continue

    trn_hdf5.close()
    val_hdf5.close()
    tst_hdf5.close()

    pixel_total = 0
    # adds up the number of pixels for each class in pixel_classes dict
    for i in pixel_classes:
        pixel_total += pixel_classes[i]
    # calculate the proportion of pixels of each class for the samples created
    pixel_classes_dict = {}
    for i in pixel_classes:
        # prop = round((pixel_classes[i] / pixel_total) * 100, 1) if pixel_total > 0 else 0
        pixel_classes_dict[i] = round((pixel_classes[i] / pixel_total) *
                                      100, 1) if pixel_total > 0 else 0
    # prints the proportion of pixels of each class for the samples created
    msg_pixel_classes = "\n".join("Pixels from class {}: {}%".format(k, v)
                                  for k, v in pixel_classes_dict.items())
    logging.info("\n" + msg_pixel_classes)

    logging.info(f"\nNumber of samples created: {number_samples}")

    if bucket_name and final_samples_folder:  # FIXME: final_samples_folder always None in current implementation
        logging.info('\nTransfering Samples to the bucket')
        bucket.upload_file(samples_dir + "/trn_samples.hdf5",
                           final_samples_folder + '/trn_samples.hdf5')
        bucket.upload_file(samples_dir + "/val_samples.hdf5",
                           final_samples_folder + '/val_samples.hdf5')
        bucket.upload_file(samples_dir + "/tst_samples.hdf5",
                           final_samples_folder + '/tst_samples.hdf5')