Example no. 1
def main(params):
    """
    Training and validation datasets preparation.

    Process
    -------
    1. Read csv file and validate existence of all input files and GeoPackages.

    2. Do the following verifications:
        1. Assert number of bands found in raster is equal to desired number
           of bands.
        2. Check that `num_classes` is equal to number of classes detected in
           the specified attribute for each GeoPackage.
           Warning: this validation will not succeed if a GeoPackage
                    contains only a subset of `num_classes` (e.g. 3 of 4).
        3. Assert that the coordinate reference systems of the raster and the gpkg match.

    3. Read csv file and for each line in the file, do the following:
        1. Read input image as array with utils.readers.image_reader_as_array().
            - If gpkg's extent is smaller than raster's extent,
              raster is clipped to gpkg's extent.
            - If gpkg's extent is bigger than raster's extent,
              gpkg is clipped to raster's extent.
        2. Convert GeoPackage vector information into the "label" raster with
           utils.utils.vector_to_raster(). The pixel value is determined by the
           attribute in the csv file.
        3. Create a new raster called "label" with the same properties as the
           input image.
        4. Read metadata and add to input as new bands (*more details to come*).
        5. Crop the arrays into smaller samples of size `samples_size` from
           `your_conf.yaml` (see the tiling sketch after this example). A
           visual representation is provided at
            https://medium.com/the-downlinq/broad-area-satellite-imagery-semantic-segmentation-basiss-4a7ea2c8466f
        6. Write samples from input image and label into the "val", "trn" or
           "tst" hdf5 file, depending on the value contained in the csv file.
            Refer to samples_preparation().

    -------
    :param params: (dict) Parameters found in the yaml config file.
    """
    params['global']['git_hash'] = get_git_hash()
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []

    assert params['global']['task'] == 'segmentation', \
        "images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = get_key_def('bucket_name', params['global'])
    data_path = Path(params['global']['data_path'])
    Path.mkdir(data_path, exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = get_key_def('min_annotated_percent',
                                 params['sample']['sampling_method'],
                                 None,
                                 expected_type=int)
    num_bands = params['global']['number_of_bands']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activated. Execution may take longer...')

    final_samples_folder = None

    sample_path_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'

    # AWS
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        if data_path:
            final_samples_folder = data_path.joinpath("samples")
        else:
            final_samples_folder = "samples"
        samples_folder = Path(sample_path_name)  # Path, so .is_dir() and mkdir below also work in this branch

    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(sample_path_name)

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    samples_folder.mkdir(exist_ok=False)  # TODO: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(f'\nSuccessfully read csv file: {Path(csv_file).stem}\n'
               f'Number of rows: {len(list_data_prep)}\n'
               f'Copying first entry:\n{list_data_prep[0]}\n')
    ignore_index = get_key_def('ignore_index', params['training'], -1)
    meta_map, metadata = get_key_def("meta_map", params["global"], {}), None

    # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg and (2) check CRS match (tif and gpkg)
    valid_gpkg_set = set()
    for info in tqdm(list_data_prep, position=0):
        assert_num_bands(info['tif'], num_bands, meta_map)
        if info['gpkg'] not in valid_gpkg_set:
            gpkg_classes = validate_num_classes(
                info['gpkg'], params['global']['num_classes'],
                info['attribute_name'], ignore_index)
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    if debug:
        # VALIDATION (debug only): Checking validity of features in vector files
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Checking validity of features in vector files"):
            invalid_features = validate_features_from_gpkg(
                info['gpkg'], info['attribute_name']
            )  # TODO: test this with invalid features.
            assert not invalid_features, f"{info['gpkg']}: Invalid geometry object(s) '{invalid_features}'"

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    class_prop = get_key_def('class_proportion',
                             params['sample']['sampling_method'],
                             None,
                             expected_type=dict)

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)

    # Set dontcare (aka ignore_index) value
    dontcare = get_key_def(
        "ignore_index", params["training"],
        -1)  # TODO: deduplicate with train_segmentation, l300
    if dontcare == 0:
        warnings.warn(
            "The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
            " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
            " will be remapped to -1 while loading the dataset, and inside the config from now on."
        )
        params["training"]["ignore_index"] = -1

    # creates pixel_classes dict and keys
    pixel_classes = {key: 0 for key in gpkg_classes}
    background_val = 0
    pixel_classes[background_val] = 0
    class_prop = validate_class_prop_dict(pixel_classes, class_prop)
    pixel_classes[dontcare] = 0

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    with tqdm(list_data_prep,
              position=0,
              leave=False,
              desc=f'Preparing samples') as _tqdm:
        for info in _tqdm:
            _tqdm.set_postfix(
                OrderedDict(tif=f'{Path(info["tif"]).stem}',
                            sample_size=params['global']['samples_size']))
            try:
                if bucket_name:
                    bucket.download_file(
                        info['tif'], "Images/" + info['tif'].split('/')[-1])
                    info['tif'] = "Images/" + info['tif'].split('/')[-1]
                    if info['gpkg'] not in bucket_file_cache:
                        bucket_file_cache.append(info['gpkg'])
                        bucket.download_file(info['gpkg'],
                                             info['gpkg'].split('/')[-1])
                    info['gpkg'] = info['gpkg'].split('/')[-1]
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'],
                                                 info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]

                with rasterio.open(info['tif'], 'r') as raster:
                    # 1. Read the input raster image
                    np_input_image, raster, dataset_nodata = image_reader_as_array(
                        input_image=raster,
                        clip_gpkg=info['gpkg'],
                        aux_vector_file=get_key_def('aux_vector_file',
                                                    params['global'], None),
                        aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                      params['global'], None),
                        aux_vector_ids=get_key_def('aux_vector_ids',
                                                   params['global'], None),
                        aux_vector_dist_maps=get_key_def(
                            'aux_vector_dist_maps', params['global'], True),
                        aux_vector_dist_log=get_key_def(
                            'aux_vector_dist_log', params['global'], True),
                        aux_vector_scale=get_key_def('aux_vector_scale',
                                                     params['global'], None))

                    # 2. Burn vector file in a raster file
                    np_label_raster = vector_to_raster(
                        vector_file=info['gpkg'],
                        input_image=raster,
                        out_shape=np_input_image.shape[:2],
                        attribute_name=info['attribute_name'],
                        fill=background_val
                    )  # background value in rasterized vector.

                    if dataset_nodata is not None:
                        # 3. Set ignore_index value in label array where nodata in raster (only if nodata across all bands)
                        np_label_raster[dataset_nodata] = dontcare

                if debug:
                    out_meta = raster.meta.copy()
                    np_image_debug = np_input_image.transpose(2, 0, 1).astype(
                        out_meta['dtype'])
                    out_meta.update({
                        "driver": "GTiff",
                        "height": np_image_debug.shape[1],
                        "width": np_image_debug.shape[2]
                    })
                    out_tif = samples_folder / f"np_input_image_{_tqdm.n}.tif"
                    print(f"DEBUG: writing clipped raster to {out_tif}")
                    with rasterio.open(out_tif, "w", **out_meta) as dest:
                        dest.write(np_image_debug)

                    out_meta = raster.meta.copy()
                    np_label_debug = np.expand_dims(
                        np_label_raster,
                        axis=2).transpose(2, 0, 1).astype(out_meta['dtype'])
                    out_meta.update({
                        "driver": "GTiff",
                        "height": np_label_debug.shape[1],
                        "width": np_label_debug.shape[2],
                        'count': 1
                    })
                    out_tif = samples_folder / f"np_label_rasterized_{_tqdm.n}.tif"
                    print(f"DEBUG: writing final rasterized gpkg to {out_tif}")
                    with rasterio.open(out_tif, "w", **out_meta) as dest:
                        dest.write(np_label_debug)

                # Mask the zeros from input image into label raster.
                if params['sample']['mask_reference']:
                    np_label_raster = mask_image(np_input_image,
                                                 np_label_raster)

                if info['dataset'] == 'trn':
                    out_file = trn_hdf5
                elif info['dataset'] == 'tst':
                    out_file = tst_hdf5
                else:
                    raise ValueError(
                        f"Dataset value must be trn or tst. Provided value is {info['dataset']}"
                    )
                val_file = val_hdf5

                metadata = add_metadata_from_raster_to_sample(
                    sat_img_arr=np_input_image,
                    raster_handle=raster,
                    meta_map=meta_map,
                    raster_info=info)
                # Save label's per class pixel count to image metadata
                metadata['source_label_bincount'] = {
                    class_num: count
                    for class_num, count in enumerate(
                        np.bincount(np_label_raster.clip(min=0).flatten()))
                    if count > 0
                }  # TODO: add this to add_metadata_from[...] function?

                np_label_raster = np.reshape(
                    np_label_raster,
                    (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                # 3. Prepare samples!
                number_samples, number_classes = samples_preparation(
                    in_img_array=np_input_image,
                    label_array=np_label_raster,
                    sample_size=samples_size,
                    overlap=overlap,
                    samples_count=number_samples,
                    num_classes=number_classes,
                    samples_file=out_file,
                    val_percent=val_percent,
                    val_sample_file=val_file,
                    dataset=info['dataset'],
                    pixel_classes=pixel_classes,
                    image_metadata=metadata,
                    dontcare=dontcare,
                    min_annot_perc=min_annot_perc,
                    class_prop=class_prop)

                _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                out_file.flush()
            except OSError as e:
                warnings.warn(
                    f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                    f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
                continue

    trn_hdf5.close()
    val_hdf5.close()
    tst_hdf5.close()

    pixel_total = 0
    # adds up the number of pixels for each class in pixel_classes dict
    for i in pixel_classes:
        pixel_total += pixel_classes[i]

    # prints the proportion of pixels of each class for the samples created
    for i in pixel_classes:
        prop = round((pixel_classes[i] / pixel_total) *
                     100, 1) if pixel_total > 0 else 0
        print('Pixels from class', i, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transferring samples to the bucket')
        bucket.upload_file(str(samples_folder / "trn_samples.hdf5"),
                           f"{final_samples_folder}/trn_samples.hdf5")
        bucket.upload_file(str(samples_folder / "val_samples.hdf5"),
                           f"{final_samples_folder}/val_samples.hdf5")
        bucket.upload_file(str(samples_folder / "tst_samples.hdf5"),
                           f"{final_samples_folder}/tst_samples.hdf5")

    print("End of process")
Example no. 2
def main(params: dict):
    """
    Identify the class to which each image belongs.
    :param params: (dict) Parameters found in the yaml config file.

    """
    # SET BASIC VARIABLES AND PATHS
    since = time.time()

    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activated. Some debug features may use extra disk space and slow down execution.')

    num_classes = params['global']['num_classes']
    task = params['global']['task']
    num_classes_corrected = add_background_to_num_class(task, num_classes)

    chunk_size = get_key_def('chunk_size', params['inference'], 512)
    overlap = get_key_def('overlap', params['inference'], 10)
    nbr_pix_overlap = int(math.floor(overlap / 100 * chunk_size))
    num_bands = params['global']['number_of_bands']

    img_dir_or_csv = params['inference']['img_dir_or_csv_file']

    default_working_folder = Path(params['inference']['state_dict_path']).parent.joinpath(f'inference_{num_bands}bands')
    working_folder = get_key_def('working_folder', params['inference'], None)
    if working_folder:  # TODO: July 2020: deprecation started. Remove custom working_folder parameter as of Sept 2020?
        working_folder = Path(working_folder)
        warnings.warn(f"Deprecated parameter. Remove it in your future yamls as this folder is now created "
                      f"automatically in a logical path, "
                      f"i.e. [state_dict_path from inference section in yaml]/inference_[num_bands]bands")
    else:
        working_folder = default_working_folder
    Path.mkdir(working_folder, exist_ok=True)
    print(f'Inferences will be saved to: {working_folder}\n\n')

    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['global'])

    # CONFIGURE MODEL
    model, state_dict_path, model_name = net(params, num_channels=num_classes_corrected, inference=True)

    num_devices = params['global']['num_gpus'] if params['global']['num_gpus'] else 0
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')

    if lst_device_ids:
        print(f"Number of cuda devices requested: {num_devices}. Cuda devices available: {lst_device_ids}. Using {lst_device_ids[0]}\n\n")
    else:
        warnings.warn(f"No Cuda device available. This process will only run on CPU")

    try:
        model.to(device)
    except RuntimeError:
        print(f"Unable to use device. Trying device 0")
        device = torch.device(f'cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    list_img = list_input_images(img_dir_or_csv, bucket_name, glob_patterns=["*.tif", "*.TIF"])

    if task == 'classification':
        classifier(params, list_img, model, device, working_folder)  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        if bucket:
            bucket.download_file(state_dict_path, "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(state_dict_path, model)

        ignore_index = get_key_def('ignore_index', params['training'], -1)
        meta_map, yaml_metadata = get_key_def("meta_map", params["global"], {}), None

        # LOOP THROUGH LIST OF INPUT IMAGES
        with tqdm(list_img, desc='image list', position=0) as _tqdm:
            for info in _tqdm:
                img_name = Path(info['tif']).name
                if bucket:
                    local_img = f"Images/{img_name}"
                    bucket.download_file(info['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'], info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]
                else:  # FIXME: else statement should support img['meta'] integration as well.
                    local_img = Path(info['tif'])
                    inference_image = working_folder.joinpath(f"{img_name.split('.')[0]}_inference.tif")

                assert local_img.is_file(), f"Could not open raster file at {local_img}"

                # Empty sample as dictionary
                inf_sample = {'sat_img': None, 'metadata': None}

                with rasterio.open(local_img, 'r') as raster_handle:
                    inf_sample['sat_img'], raster_handle_updated, dataset_nodata = image_reader_as_array(
                                    input_image=raster_handle,
                                    aux_vector_file=get_key_def('aux_vector_file', params['global'], None),
                                    aux_vector_attrib=get_key_def('aux_vector_attrib', params['global'], None),
                                    aux_vector_ids=get_key_def('aux_vector_ids', params['global'], None),
                                    aux_vector_dist_maps=get_key_def('aux_vector_dist_maps', params['global'], True),
                                    aux_vector_scale=get_key_def('aux_vector_scale', params['global'], None))

                inf_sample['metadata'] = add_metadata_from_raster_to_sample(sat_img_arr=inf_sample['sat_img'],
                                                                            raster_handle=raster_handle_updated,
                                                                            meta_map=meta_map,
                                                                            raster_info=info)

                _tqdm.set_postfix(OrderedDict(img_name=img_name,
                                              img=inf_sample['sat_img'].shape,
                                              img_min_val=np.min(inf_sample['sat_img']),
                                              img_max_val=np.max(inf_sample['sat_img'])))

                input_band_count = inf_sample['sat_img'].shape[2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                if input_band_count > num_bands:  # TODO: move as new function in utils.verifications (see the sketch after this example)
                    # FIXME: Following statements should be reconsidered to better manage inconsistencies between
                    #  provided number of band and image number of band.
                    warnings.warn(f"Input image has more band than the number provided in the yaml file ({num_bands}). "
                                  f"Will use the first {num_bands} bands of the input image.")
                    inf_sample['sat_img'] = inf_sample['sat_img'][:, :, 0:num_bands]
                    print(f"Input image's new shape: {inf_sample['sat_img'].shape}")

                elif input_band_count < num_bands:
                    warnings.warn(f"Skipping image: The number of bands requested in the yaml file ({num_bands})"
                                  f"can not be larger than the number of band in the input image ({input_band_count}).")
                    continue

                # START INFERENCES ON SUB-IMAGES
                sem_seg_results_per_class = sem_seg_inference(model,
                                                              inf_sample['sat_img'],
                                                              nbr_pix_overlap,
                                                              chunk_size,
                                                              num_classes_corrected,
                                                              device,
                                                              meta_map,
                                                              inf_sample['metadata'],
                                                              output_path=working_folder,
                                                              index=_tqdm.n,
                                                              debug=debug)

                # CREATE GEOTIF FROM METADATA OF ORIGINAL IMAGE
                tqdm.write(f'Saving inference...\n')
                if get_key_def('heatmaps', params['inference'], False):
                    tqdm.write(f'Heatmaps will be saved.\n')
                vis(params, inf_sample['sat_img'], sem_seg_results_per_class, working_folder, inference_input_path=local_img, debug=debug)

                tqdm.write(f"\n\nSemantic segmentation of image {img_name} completed\n\n")
                if bucket:
                    bucket.upload_file(inference_image, os.path.join(working_folder, f"{img_name.split('.')[0]}_inference.tif"))
    else:
        raise ValueError(
            f"The task should be either classification or segmentation. The provided value is {params['global']['task']}")

    time_elapsed = time.time() - since
    print('Inference completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
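
The band-count check above carries a TODO to move it into utils.verifications. A minimal sketch of what such a helper could look like, assuming the same warn-and-trim behavior as the inline code (the name and signature are hypothetical, not the project's actual API):

import warnings
import numpy as np

def check_input_band_count(sat_img, expected_bands, meta_layers=0):
    """Trim sat_img to expected_bands, or raise if it has too few bands."""
    found = sat_img.shape[2] + meta_layers
    if found > expected_bands:
        # Mirror the inline behavior: warn, then keep only the first bands.
        warnings.warn(f"Input image has {found} bands; using the first {expected_bands}.")
        return sat_img[:, :, :expected_bands]
    if found < expected_bands:
        raise ValueError(f"Expected {expected_bands} bands, found only {found}.")
    return sat_img

Raising instead of the inline `continue` lets the caller decide whether to skip the image or abort the run.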
Example no. 3
def main(params):
    """
    Identify the class to which each image belongs.
    :param params: (dict) Parameters found in the yaml config file.

    """
    # SET BASIC VARIABLES AND PATHS
    since = time.time()

    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activated. Some debug features may use extra disk space and slow down execution.')

    num_classes = params['global']['num_classes']
    if params['global']['task'] == 'segmentation':
        # assume background is implicitly needed (makes no sense to predict with one class, for example.)
        # this will trigger some warnings elsewhere, but should succeed nonetheless
        num_classes_corrected = num_classes + 1 # + 1 for background # FIXME temporary patch for num_classes problem.
    elif params['global']['task'] == 'classification':
        num_classes_corrected = num_classes

    chunk_size = get_key_def('chunk_size', params['inference'], 512)
    overlap = get_key_def('overlap', params['inference'], 10)
    nbr_pix_overlap = int(math.floor(overlap / 100 * chunk_size))
    num_bands = params['global']['number_of_bands']

    img_dir_or_csv = params['inference']['img_dir_or_csv_file']

    default_working_folder = Path(params['inference']['state_dict_path']).parent.joinpath(f'inference_{num_bands}bands')
    working_folder = Path(get_key_def('working_folder', params['inference'], default_working_folder)) # TODO: remove working_folder parameter in all templates
    Path.mkdir(working_folder, exist_ok=True)
    print(f'Inferences will be saved to: {working_folder}\n\n')

    bucket = None
    bucket_file_cache = []
    bucket_name = params['global']['bucket_name']

    # CONFIGURE MODEL
    model, state_dict_path, model_name = net(params, num_channels=num_classes_corrected, inference=True)

    num_devices = params['global']['num_gpus'] if params['global']['num_gpus'] else 0
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')

    if lst_device_ids:
        print(f"Number of cuda devices requested: {num_devices}. Cuda devices available: {lst_device_ids}. Using {lst_device_ids[0]}\n\n")
    else:
        warnings.warn(f"No Cuda device available. This process will only run on CPU")

    try:
        model.to(device)
    except RuntimeError:
        print(f"Unable to use device. Trying device 0")
        device = torch.device(f'cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        if img_dir_or_csv.endswith('.csv'):
            bucket.download_file(img_dir_or_csv, 'img_csv_file.csv')
            list_img = read_csv('img_csv_file.csv', inference=True)
        else:
            raise NotImplementedError(
                'Specify a csv file containing images for inference. Directory input not implemented yet')
    else:
        if img_dir_or_csv.endswith('.csv'):
            list_img = read_csv(img_dir_or_csv, inference=True)
        else:
            img_dir = Path(img_dir_or_csv)
            assert img_dir.is_dir(), f'Could not find directory "{img_dir_or_csv}"'
            list_img_paths = sorted(list(img_dir.glob('*.tif')) + list(img_dir.glob('*.TIF')))  # also match upper-case .TIF
            list_img = [{'tif': img_path} for img_path in list_img_paths]
            assert len(list_img) > 0, f'No .tif files found in {img_dir_or_csv}'

    if params['global']['task'] == 'classification':
        classifier(params, list_img, model, device, working_folder)  # FIXME: why don't we load from checkpoint in classification?

    elif params['global']['task'] == 'segmentation':
        if bucket:
            bucket.download_file(state_dict_path, "saved_model.pth.tar")
            model, _ = load_from_checkpoint("saved_model.pth.tar", model, inference=True)
        else:
            model, _ = load_from_checkpoint(state_dict_path, model, inference=True)

        with tqdm(list_img, desc='image list', position=0) as _tqdm:
            for img in _tqdm:
                img_name = Path(img['tif']).name
                if bucket:
                    local_img = f"Images/{img_name}"
                    bucket.download_file(img['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if img['meta']:
                        if img['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(img['meta'])
                            bucket.download_file(img['meta'], img['meta'].split('/')[-1])
                        img['meta'] = img['meta'].split('/')[-1]
                else:
                    local_img = Path(img['tif'])
                    inference_image = working_folder.joinpath(f"{img_name.split('.')[0]}_inference.tif")

                assert local_img.is_file(), f"Could not open raster file at {local_img}"

                scale = get_key_def('scale_data', params['global'], None)
                with rasterio.open(local_img, 'r') as raster:

                    np_input_image = image_reader_as_array(input_image=raster,
                                                           scale=scale,
                                                           aux_vector_file=get_key_def('aux_vector_file',
                                                                                       params['global'], None),
                                                           aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                                                         params['global'], None),
                                                           aux_vector_ids=get_key_def('aux_vector_ids',
                                                                                      params['global'], None),
                                                           aux_vector_dist_maps=get_key_def('aux_vector_dist_maps',
                                                                                            params['global'], True),
                                                           aux_vector_scale=get_key_def('aux_vector_scale',
                                                                                        params['global'], None))

                meta_map, metadata = get_key_def("meta_map", params["global"], {}), None
                if meta_map:
                    assert img['meta'] is not None and isinstance(img['meta'], str) and os.path.isfile(img['meta']), \
                        "global configuration requested metadata mapping onto loaded samples, but raster did not have available metadata"
                    metadata = read_parameters(img['meta'])

                if debug:
                    _tqdm.set_postfix(OrderedDict(img_name=img_name,
                                                  img=np_input_image.shape,
                                                  img_min_val=np.min(np_input_image),
                                                  img_max_val=np.max(np_input_image)))

                input_band_count = np_input_image.shape[2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                if input_band_count > params['global']['number_of_bands']:
                    # FIXME: Following statements should be reconsidered to better manage inconsistencies between
                    #  provided number of band and image number of band.
                    warnings.warn(f"Input image has more band than the number provided in the yaml file ({params['global']['number_of_bands']}). "
                                  f"Will use the first {params['global']['number_of_bands']} bands of the input image.")
                    np_input_image = np_input_image[:, :, 0:params['global']['number_of_bands']]
                    print(f"Input image's new shape: {np_input_image.shape}")

                elif input_band_count < params['global']['number_of_bands']:
                    warnings.warn(f"Skipping image: The number of bands requested in the yaml file ({params['global']['number_of_bands']})"
                                  f"can not be larger than the number of band in the input image ({input_band_count}).")
                    continue

                # START INFERENCES ON SUB-IMAGES
                sem_seg_results_per_class = sem_seg_inference(model, np_input_image, nbr_pix_overlap, chunk_size, num_classes_corrected,
                                                    device, meta_map, metadata, output_path=working_folder, index=_tqdm.n, debug=debug)

                # CREATE GEOTIF FROM METADATA OF ORIGINAL IMAGE
                tqdm.write(f'Saving inference...\n')
                if get_key_def('heatmaps', params['inference'], False):
                    tqdm.write(f'Heatmaps will be saved.\n')
                vis(params, np_input_image, sem_seg_results_per_class, working_folder, inference_input_path=local_img, debug=debug)

                tqdm.write(f"\n\nSemantic segmentation of image {img_name} completed\n\n")
                if bucket:
                    bucket.upload_file(inference_image, os.path.join(working_folder, f"{img_name.split('.')[0]}_inference.tif"))
    else:
        raise ValueError(
            f"The task should be either classification or segmentation. The provided value is {params['global']['task']}")

    time_elapsed = time.time() - since
    print('Inference completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
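
Example no. 2 replaces the inline image-listing logic above with list_input_images(img_dir_or_csv, bucket_name, glob_patterns=["*.tif", "*.TIF"]), which also answers the FIXME about upper-case extensions. A plausible minimal version covering only the local (non-bucket) cases might read as follows; this is an assumption about the helper, not its actual implementation, and read_csv is the project's own csv reader:

from pathlib import Path

def list_input_images(img_dir_or_csv, glob_patterns=("*.tif", "*.TIF")):
    """Return a list of {'tif': path} dicts from a csv file or a directory."""
    if str(img_dir_or_csv).endswith('.csv'):
        return read_csv(img_dir_or_csv, inference=True)  # project helper
    img_dir = Path(img_dir_or_csv)
    assert img_dir.is_dir(), f'Could not find directory "{img_dir_or_csv}"'
    paths = sorted(p for pattern in glob_patterns for p in img_dir.glob(pattern))
    assert len(paths) > 0, f'No raster files found in {img_dir_or_csv}'
    return [{'tif': p} for p in paths]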
Example no. 4
def create_csv():
    """
    Creates samples from the input images for the pixel_inventory function

    """
    prep_csv_path = params['sample']['prep_csv_file']
    dist_samples = params['sample']['samples_dist']
    sample_size = params['global']['samples_size']
    data_path = params['global']['data_path']
    Path.mkdir(Path(data_path), exist_ok=True)
    num_classes = params['global']['num_classes']
    data_prep_csv = read_csv(prep_csv_path)

    csv_prop_data = params['global']['data_path'] + '/prop_data.csv'
    if os.path.isfile(csv_prop_data):
        os.remove(csv_prop_data)

    with tqdm(data_prep_csv) as _tqdm:
        for info in _tqdm:

            _tqdm.set_postfix(OrderedDict(file=f'{info["tif"]}', sample_size=params['global']['samples_size']))

            # Validate the number of class in the vector file
            validate_num_classes(info['gpkg'], num_classes, info['attribute_name'])

            assert os.path.isfile(info['tif']), f"could not open raster file at {info['tif']}"
            with rasterio.open(info['tif'], 'r') as raster:

                # Burn vector file in a raster file
                np_label_raster = vector_to_raster(vector_file=info['gpkg'],
                                                   input_image=raster,
                                                   attribute_name=info['attribute_name'],
                                                   fill=get_key_def('ignore_idx', get_key_def('training', params, {}),
                                                                    0))

                # Read the input raster image
                np_input_image = image_reader_as_array(input_image=raster,
                                                       aux_vector_file=get_key_def('aux_vector_file', params['global'],
                                                                                   None),
                                                       aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                                                     params['global'], None),
                                                       aux_vector_ids=get_key_def('aux_vector_ids', params['global'],
                                                                                  None),
                                                       aux_vector_dist_maps=get_key_def('aux_vector_dist_maps',
                                                                                        params['global'], True),
                                                       aux_vector_dist_log=get_key_def('aux_vector_dist_log',
                                                                                       params['global'], True),
                                                       aux_vector_scale=get_key_def('aux_vector_scale',
                                                                                    params['global'], None))
                # Mask the zeros from input image into label raster.
                if params['sample']['mask_reference']:
                    np_label_raster = images_to_samples.mask_image(np_input_image, np_label_raster)

                np_label_raster = np.reshape(np_label_raster, (np_label_raster.shape[0], np_label_raster.shape[1], 1))

                h, w, num_bands = np_input_image.shape

                # half tile padding
                half_tile = int(sample_size / 2)
                pad_label_array = np.pad(np_label_raster, ((half_tile, half_tile), (half_tile, half_tile), (0, 0)),
                                         mode='constant')

                for row in range(0, h, dist_samples):
                    for column in range(0, w, dist_samples):
                        target = np.squeeze(pad_label_array[row:row + sample_size, column:column + sample_size, :], axis=2)

                        pixel_inventory(target, sample_size, params['global']['num_classes'] + 1,
                                        params['global']['data_path'], info['dataset'])
def main(params):
    """
    Training and validation datasets preparation.
    :param params: (dict) Parameters found in the yaml config file.

    """
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []

    assert params['global']['task'] == 'segmentation', \
        "images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = params['global']['bucket_name']
    data_path = Path(params['global']['data_path'])
    Path.mkdir(data_path, exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = params['sample']['sampling']['map']
    num_bands = params['global']['number_of_bands']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activated. Execution may take longer...')

    final_samples_folder = None
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        if data_path:
            final_samples_folder = os.path.join(data_path, "samples")
        else:
            final_samples_folder = "samples"
        samples_folder = Path(f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands')  # TODO: validate this is preferred name structure; Path so .is_dir() and mkdir below work

    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(
            f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'
        )

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    samples_folder.mkdir(exist_ok=False)  # FIXME: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(
        f'\nSuccessfully read csv file: {Path(csv_file).stem}\nNumber of rows: {len(list_data_prep)}\nCopying first entry:\n{list_data_prep[0]}\n'
    )
    ignore_index = get_key_def('ignore_index', params['training'], -1)

    for info in tqdm(list_data_prep,
                     position=0,
                     desc=f'Asserting existence of tif and gpkg files in csv'):
        assert Path(info['tif']).is_file(
        ), f'Could not locate "{info["tif"]}". Make sure file exists in this directory.'
        assert Path(info['gpkg']).is_file(
        ), f'Could not locate "{info["gpkg"]}". Make sure file exists in this directory.'
    if debug:
        # Note: the tqdm desc is evaluated once, before the loop, so it must not
        # reference the loop variable `info` (which would be stale here).
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Validating presence of {params['global']['num_classes']} "
                f"classes in the specified attribute for each vector file"):
            validate_num_classes(info['gpkg'], params['global']['num_classes'],
                                 info['attribute_name'], ignore_index)
        with tqdm(list_data_prep,
                  position=0,
                  desc=f"Checking validity of features in vector files"
                  ) as _tqdm:
            invalid_features = {}
            for info in _tqdm:
                # Extract vector features to burn in the raster image
                with fiona.open(
                        info['gpkg'],
                        'r') as src:  # TODO: refactor as independent function
                    lst_vector = [vector for vector in src]
                shapes = lst_ids(list_vector=lst_vector,
                                 attr_name=info['attribute_name'])
                for index, item in enumerate(
                        tqdm([v for vecs in shapes.values() for v in vecs],
                             leave=False,
                             position=1)):
                    # geom must be a valid GeoJSON geometry type and non-empty
                    geom, value = item
                    geom = getattr(geom, '__geo_interface__', None) or geom
                    if not is_valid_geom(geom):
                        gpkg_stem = str(Path(info['gpkg']).stem)
                        if gpkg_stem not in invalid_features.keys(
                        ):  # create key with name of gpkg
                            invalid_features[gpkg_stem] = []
                        if lst_vector[index]["id"] not in invalid_features[
                                gpkg_stem]:  # ignore feature if already appended
                            invalid_features[gpkg_stem].append(
                                lst_vector[index]["id"])
            assert not invalid_features, f'Invalid geometry object(s) for "gpkg:ids": "{invalid_features}"'

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    # 'sampling' ordereddict validation
    check_sampling_dict()

    pixel_classes = {}
    # creates pixel_classes dict and keys
    for i in range(0, params['global']['num_classes'] + 1):
        pixel_classes.update({i: 0})
    pixel_classes.update(
        {ignore_index: 0}
    )  # FIXME: pixel_classes dict needs to be populated with classes obtained from target

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    with tqdm(list_data_prep,
              position=0,
              leave=False,
              desc=f'Preparing samples') as _tqdm:
        for info in _tqdm:
            _tqdm.set_postfix(
                OrderedDict(tif=f'{Path(info["tif"]).stem}',
                            sample_size=params['global']['samples_size']))
            try:
                if bucket_name:
                    bucket.download_file(
                        info['tif'], "Images/" + info['tif'].split('/')[-1])
                    info['tif'] = "Images/" + info['tif'].split('/')[-1]
                    if info['gpkg'] not in bucket_file_cache:
                        bucket_file_cache.append(info['gpkg'])
                        bucket.download_file(info['gpkg'],
                                             info['gpkg'].split('/')[-1])
                    info['gpkg'] = info['gpkg'].split('/')[-1]
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'],
                                                 info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]

                with rasterio.open(info['tif'], 'r') as raster:
                    # Burn vector file in a raster file
                    np_label_raster = vector_to_raster(
                        vector_file=info['gpkg'],
                        input_image=raster,
                        attribute_name=info['attribute_name'],
                        fill=get_key_def('ignore_idx',
                                         get_key_def('training', params, {}),
                                         0))
                    # Read the input raster image
                    np_input_image = image_reader_as_array(
                        input_image=raster,
                        scale=get_key_def('scale_data', params['global'],
                                          None),
                        aux_vector_file=get_key_def('aux_vector_file',
                                                    params['global'], None),
                        aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                      params['global'], None),
                        aux_vector_ids=get_key_def('aux_vector_ids',
                                                   params['global'], None),
                        aux_vector_dist_maps=get_key_def(
                            'aux_vector_dist_maps', params['global'], True),
                        aux_vector_dist_log=get_key_def(
                            'aux_vector_dist_log', params['global'], True),
                        aux_vector_scale=get_key_def('aux_vector_scale',
                                                     params['global'], None))

                # Mask the zeros from input image into label raster.
                if params['sample']['mask_reference']:
                    np_label_raster = mask_image(np_input_image,
                                                 np_label_raster)

                if info['dataset'] == 'trn':
                    out_file = trn_hdf5
                elif info['dataset'] == 'tst':
                    out_file = tst_hdf5
                else:
                    raise ValueError(
                        f"Dataset value must be trn or tst. Provided value is {info['dataset']}"
                    )
                val_file = val_hdf5  # val samples are carved out of trn via val_percent

                meta_map, metadata = get_key_def("meta_map", params["global"],
                                                 {}), None
                if info['meta'] is not None and isinstance(
                        info['meta'], str) and Path(info['meta']).is_file():
                    metadata = read_parameters(info['meta'])

                # FIXME: think this through. User will have to calculate the total number of bands including meta layers and
                #  specify it in yaml. Is this the best approach? What if metalayers are added on the fly ?
                input_band_count = np_input_image.shape[
                    2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                # FIXME: could this assert be done before getting into this big for loop?
                assert input_band_count == num_bands, \
                    f"The number of bands in the input image ({input_band_count}) and the parameter " \
                    f"'number_of_bands' in the yaml file ({params['global']['number_of_bands']}) should be identical"

                np_label_raster = np.reshape(
                    np_label_raster,
                    (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                number_samples, number_classes = samples_preparation(
                    np_input_image, np_label_raster, samples_size, overlap,
                    number_samples, number_classes, out_file, val_percent,
                    val_file, info['dataset'], pixel_classes, metadata)

                _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                out_file.flush()
            except Exception as e:
                warnings.warn(
                    f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                    f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
                continue

    trn_hdf5.close()
    val_hdf5.close()
    tst_hdf5.close()

    pixel_total = 0
    # adds up the number of pixels for each class in pixel_classes dict
    for i in pixel_classes:
        pixel_total += pixel_classes[i]

    # prints the proportion of pixels of each class for the samples created
    for i in pixel_classes:
        prop = round((pixel_classes[i] / pixel_total) *
                     100, 1) if pixel_total > 0 else 0
        print('Pixels from class', i, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transferring samples to the bucket')
        bucket.upload_file(str(samples_folder / "trn_samples.hdf5"),
                           os.path.join(final_samples_folder, 'trn_samples.hdf5'))
        bucket.upload_file(str(samples_folder / "val_samples.hdf5"),
                           os.path.join(final_samples_folder, 'val_samples.hdf5'))
        bucket.upload_file(str(samples_folder / "tst_samples.hdf5"),
                           os.path.join(final_samples_folder, 'tst_samples.hdf5'))

    print("End of process")
Example no. 6
def segmentation_with_smoothing(raster, clip_gpkg, model, sample_size, overlap,
                                num_bands, device):
    # switch to evaluate mode
    model.eval()
    img_array, input_image, dataset_nodata = image_reader_as_array(
        input_image=raster, clip_gpkg=clip_gpkg)
    metadata = add_metadata_from_raster_to_sample(img_array,
                                                  input_image,
                                                  meta_map=None,
                                                  raster_info=None)
    h, w, bands = img_array.shape
    assert num_bands <= bands, f"Num of specified bands is not compatible with image shape {img_array.shape}"
    if num_bands < bands:
        img_array = img_array[:, :, :num_bands]

    padding = int(round(sample_size * (1 - 1.0 / overlap)))
    padded_img = pad(img_array, padding=padding, fill=0)
    WINDOW_SPLINE_2D = _window_2D(window_size=sample_size, power=1)
    WINDOW_SPLINE_2D = np.moveaxis(WINDOW_SPLINE_2D, 2, 0)
    step = int(sample_size / overlap)
    h_, w_ = padded_img.shape[:2]
    pred_img = np.empty((h_, w_), dtype=np.uint8)
    for row in tqdm(range(0, h_ - sample_size + 1, step),
                    position=1,
                    leave=False,
                    desc='Inferring rows'):
        with tqdm(range(0, w_ - sample_size + 1, step),
                  position=2,
                  leave=False,
                  desc='Inferring columns') as _tqdm:
            for col in _tqdm:
                sample = {'sat_img': None, 'metadata': None}
                sample['metadata'] = metadata
                totensor_transform = augmentation.compose_transforms(
                    params, dataset="tst", type='totensor')
                sub_images = padded_img[row:row + sample_size,
                                        col:col + sample_size, :]
                sample['sat_img'] = sub_images
                sample = totensor_transform(sample)
                inputs = sample['sat_img'].unsqueeze_(0)
                inputs = inputs.to(device)

                if inputs.shape[1] == 4 and any(
                        "module.modelNIR" in s
                        for s in model.state_dict().keys()):
                    ############################
                    # Test Implementation of the NIR
                    ############################
                    # Init NIR   TODO: make a proper way to read the NIR channel
                    #                  and add an option to give the index of the NIR channel
                    # Extract the NIR channel -> [batch size, H, W] since it's only one channel
                    inputs_NIR = inputs[:, -1, ...]
                    # add a channel to get the good size -> [:, 1, :, :]
                    inputs_NIR.unsqueeze_(1)
                    # take out the NIR channel and take only the RGB for the inputs
                    inputs = inputs[:, :-1, ...]
                    # Suggestion of implementation
                    #inputs_NIR = data['NIR'].to(device)
                    inputs = [inputs, inputs_NIR]
                    #outputs = model(inputs, inputs_NIR)
                    ############################
                    # End of the test implementation module
                    ############################

                outputs = model(inputs)
                # torchvision models give output in 'out' key.
                # May cause problems in future versions of torchvision.
                if isinstance(outputs,
                              OrderedDict) and 'out' in outputs.keys():
                    outputs = outputs['out']
                outputs = F.softmax(outputs,
                                    dim=1).squeeze(dim=0).cpu().numpy()
                outputs = WINDOW_SPLINE_2D * outputs
                outputs = outputs.argmax(axis=0)
                pred_img[row:row + sample_size,
                         col:col + sample_size] = outputs
    pred_img = pred_img[padding:-padding, padding:-padding]
    return pred_img[:h, :w]
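
segmentation_with_smoothing depends on _window_2D to weight each chip so that predictions fade toward chip borders before overlapping tiles are blended. A minimal sketch of such a window, assuming a triangle window raised to `power` and shaped (window_size, window_size, 1) to match the np.moveaxis call above (the project's actual helper is likely a smoother spline window):

import numpy as np

def _window_2D(window_size, power=1):
    """Return a (window_size, window_size, 1) tapering window."""
    # 1-D triangle window: full weight at the center, zero at the edges.
    w_1d = (1 - np.abs(np.linspace(-1, 1, window_size))) ** power
    # The outer product turns it into a 2-D bump; add a channel axis so that
    # np.moveaxis(..., 2, 0) yields shape (1, size, size) for broadcasting.
    return np.outer(w_1d, w_1d)[:, :, np.newaxis]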
Example no. 7
def segmentation(raster, clip_gpkg, model, sample_size, num_bands, device):
    # switch to evaluate mode
    model.eval()
    img_array, input_image, dataset_nodata = image_reader_as_array(
        input_image=raster, clip_gpkg=clip_gpkg)
    metadata = add_metadata_from_raster_to_sample(img_array,
                                                  input_image,
                                                  meta_map=None,
                                                  raster_info=None)
    h, w, bands = img_array.shape
    assert num_bands <= bands, f"Num of specified bands is not compatible with image shape {img_array.shape}"
    if num_bands < bands:
        img_array = img_array[:, :, :num_bands]
    h_ = sample_size * math.ceil(h / sample_size)
    w_ = sample_size * math.ceil(w / sample_size)
    pred_img = np.empty((h_, w_), dtype=np.uint8)
    for row in tqdm(range(0, h, sample_size),
                    position=1,
                    leave=False,
                    desc='Inferring rows'):
        with tqdm(range(0, w, sample_size),
                  position=2,
                  leave=False,
                  desc='Inferring columns') as _tqdm:
            for column in _tqdm:
                sample = {'sat_img': None, 'metadata': None}
                sample['metadata'] = metadata
                totensor_transform = augmentation.compose_transforms(
                    params, dataset="tst", type='totensor')
                sub_images = img_array[row:row + sample_size,
                                       column:column + sample_size, :]
                sub_images_row = sub_images.shape[0]
                sub_images_col = sub_images.shape[1]

                if sub_images_row < sample_size or sub_images_col < sample_size:
                    padding = pad_diff(actual_height=sub_images_row,
                                       actual_width=sub_images_col,
                                       desired_shape=sample_size)
                    sub_images = pad(
                        sub_images, padding, fill=0
                    )  # FIXME combine pad and pad_diff into one function
                sample['sat_img'] = sub_images
                sample = totensor_transform(sample)
                inputs = sample['sat_img'].unsqueeze_(0)
                inputs = inputs.to(device)

                if inputs.shape[1] == 4 and any(
                        "module.modelNIR" in s
                        for s in model.state_dict().keys()):
                    ############################
                    # Test Implementation of the NIR
                    ############################
                    # Init NIR   TODO: implement a proper way to read the NIR channel
                    #                  and add an option to specify the index of the NIR channel
                    # Extract the NIR channel -> [batch size, H, W] since it's only one channel
                    inputs_NIR = inputs[:, -1, ...]
                    # Add a channel dimension to restore the expected shape -> [:, 1, :, :]
                    inputs_NIR.unsqueeze_(1)
                    # Drop the NIR channel and keep only the RGB bands as inputs
                    inputs = inputs[:, :-1, ...]
                    # Suggested alternative implementation:
                    # inputs_NIR = data['NIR'].to(device)
                    inputs = [inputs, inputs_NIR]
                    # outputs = model(inputs, inputs_NIR)
                    ############################
                    # End of the test implementation module
                    ############################

                outputs = model(inputs)
                # torchvision models give output in 'out' key. May cause problems in future versions of torchvision.
                if isinstance(outputs,
                              OrderedDict) and 'out' in outputs.keys():
                    outputs = outputs['out']
                outputs = F.softmax(
                    outputs, dim=1).argmax(dim=1).squeeze(dim=0).cpu().numpy()

                pred_img[row:row + sample_size,
                         column:column + sample_size] = outputs

    return pred_img[:h, :w]
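segmentation() rounds each image axis up to the next multiple of sample_size so that padded edge tiles fit into a full prediction canvas, which is cropped back to [:h, :w] at the end. A small sketch of that arithmetic, with assumed example dimensions:

import math

def padded_canvas(h: int, w: int, sample_size: int):
    # Round each axis up to the next multiple of sample_size,
    # mirroring the h_/w_ computation in segmentation().
    h_ = sample_size * math.ceil(h / sample_size)
    w_ = sample_size * math.ceil(w / sample_size)
    return h_, w_

# Example: a 1000 x 900 raster tiled with 256-pixel samples
print(padded_canvas(1000, 900, 256))  # (1024, 1024)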
Example No. 8
def main(cfg: DictConfig) -> None:
    """
    Prepare the training, validation and testing datasets.

    1. Read csv file and validate existence of all input files and GeoPackages.
    2. Do the following verifications:
        1. Assert number of bands found in raster is equal to desired number
           of bands.
        2. Check that `num_classes` is equal to number of classes detected in
           the specified attribute for each GeoPackage.
           Warning: this validation will not succeed if a Geopackage
                    contains only a subset of `num_classes` (e.g. 3 of 4).
        3. Assert Coordinate reference system between raster and gpkg match.
    3. Read csv file and for each line in the file, do the following:
        1. Read input image as array with utils.readers.image_reader_as_array().
            - If gpkg's extent is smaller than raster's extent,
              raster is clipped to gpkg's extent.
            - If gpkg's extent is bigger than raster's extent,
              gpkg is clipped to raster's extent.
        2. Convert GeoPackage vector information into the "label" raster with
           utils.utils.vector_to_raster(). The pixel value is determined by the
           attribute in the csv file.
        3. Create a new raster called "label" with the same properties as the
           input image.
        4. Read metadata and add to input as new bands (*more details to come*).
        5. Crop the arrays in smaller samples of the size `samples_size` of
           `your_conf.yaml`. Visual representation of this is provided at
            https://medium.com/the-downlinq/broad-area-satellite-imagery-semantic-segmentation-basiss-4a7ea2c8466f
        6. Write samples from input image and label into the "val", "trn" or
           "tst" hdf5 file, depending on the value contained in the csv file.
            Refer to samples_preparation().

    -------
    :param cfg: (DictConfig) Parameters found in the yaml config file.
    """
    try:
        import boto3
    except ModuleNotFoundError:
        # logging.warning() takes %-format args, not a warning category;
        # passing ImportWarning here would trigger a formatting error.
        logging.warning(
            "\nThe boto3 library couldn't be imported. Ignore if not using AWS s3 buckets"
        )

    # PARAMETERS
    num_classes = len(cfg.dataset.classes_dict.keys())
    num_bands = len(cfg.dataset.modalities)
    modalities = read_modalities(
        cfg.dataset.modalities
    )  # TODO add the Victor module to manage the modalities
    debug = cfg.debug
    task = cfg.task.task_name

    # RAW DATA PARAMETERS
    # Data folder
    try:
        # Check that the folder exists
        my_data_path = Path(cfg.dataset.raw_data_dir).resolve(strict=True)
        logging.info("\nImage directory used '{}'".format(my_data_path))
        data_path = Path(my_data_path)
    except FileNotFoundError:
        # logging.critical() returns None, so it cannot be raised directly
        logging.critical(
            "\nImage directory '{}' doesn't exist, please change the path".
            format(cfg.dataset.raw_data_dir))
        raise
    # CSV file
    try:
        my_csv_path = Path(cfg.dataset.raw_data_csv).resolve(strict=True)
        logging.info("\nImage csv: '{}'".format(my_csv_path))
        csv_file = my_csv_path
    except FileNotFoundError:
        logging.critical(
            "\nImage csv '{}' doesn't exist, please change the path".format(
                cfg.dataset.raw_data_csv))
        raise
    # HDF5 data
    try:
        my_hdf5_path = Path(str(
            cfg.dataset.sample_data_dir)).resolve(strict=True)
        logging.info("\nThe HDF5 directory used '{}'".format(my_hdf5_path))
        Path.mkdir(Path(my_hdf5_path), exist_ok=True, parents=True)
    except FileNotFoundError:
        logging.info(
            "\nThe HDF5 directory '{}' doesn't exist, please change the path.".
            format(cfg.dataset.sample_data_dir) +
            "\nFor now, the HDF5 directory will fall back to '{}'".format(
                data_path))
        cfg.general.sample_data_dir = str(data_path)

    # SAMPLE PARAMETERS
    samples_size = get_key_def('input_dim',
                               cfg['dataset'],
                               default=256,
                               expected_type=int)
    overlap = get_key_def('overlap', cfg['dataset'], default=0)
    min_annot_perc = get_key_def('min_annot_perc', cfg['dataset'], default=0)
    # The default must be a dict since the result is indexed with ['val']
    val_percent = get_key_def('train_val_percent', cfg['dataset'],
                              default={'val': 0.3})['val'] * 100
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}' \
                          f'_{num_bands}bands_{cfg.general.project_name}'
    samples_dir = data_path.joinpath(samples_folder_name)
    if samples_dir.is_dir():
        if debug:
            # Move the existing data folder aside, suffixed with its last modification time.
            last_mod_time_suffix = datetime.fromtimestamp(
                samples_dir.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
            shutil.move(
                samples_dir,
                data_path.joinpath(
                    f'{str(samples_dir)}_{last_mod_time_suffix}'))
        else:
            logging.critical(
                f'Data path exists: {samples_dir}. Remove it or use a different experiment_name.'
            )
            raise FileExistsError()
    Path.mkdir(samples_dir, exist_ok=False
               )  # TODO: what if we want to append samples to existing hdf5?

    # LOGGING PARAMETERS  TODO see logging yaml
    experiment_name = cfg.general.project_name
    # mlflow_uri = get_key_def('mlflow_uri', params['global'], default="./mlruns")

    # OTHER PARAMETERS
    metadata = None
    meta_map = {}  # TODO get_key_def('meta_map', params['global'], default={})
    # TODO class_prop get_key_def('class_proportion', params['sample']['sampling_method'], None, expected_type=dict)
    class_prop = None
    mask_reference = False  # TODO get_key_def('mask_reference', params['sample'], default=False, expected_type=bool)
    # set dontcare (aka ignore_index) value
    dontcare = cfg.dataset.ignore_index if cfg.dataset.ignore_index is not None else -1
    if dontcare == 0:
        logging.warning(
            "\nThe 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero."
            " All valid class indices should be consecutive, and start at 0. The 'dontcare' value"
            " will be remapped to -1 while loading the dataset, and inside the config from now on."
        )
        dontcare = -1
    # Assert that all items in target_ids are integers (ex.: single-class samples from multi-class label)
    targ_ids = None  # TODO get_key_def('target_ids', params['sample'], None, expected_type=List)
    if isinstance(targ_ids, list):
        for item in targ_ids:
            if not isinstance(item, int):
                msg = f'\nTarget id "{item}" in target_ids is {type(item)}, expected int.'
                logging.critical(msg)
                raise ValueError(msg)

    # OPTIONAL
    use_stratification = cfg.dataset.use_stratification if cfg.dataset.use_stratification is not None else False
    if use_stratification:
        stratd = {
            'trn': {
                'total_pixels': 0,
                'total_counts': {},
                'total_props': {}
            },
            'val': {
                'total_pixels': 0,
                'total_counts': {},
                'total_props': {}
            },
            'strat_factor': cfg['dataset']['use_stratification']
        }
    else:
        stratd = None

    # ADD GIT HASH FROM CURRENT COMMIT TO PARAMETERS (if available and parameters will be saved to hdf5s).
    with open_dict(cfg):
        cfg.general.git_hash = get_git_hash()

    # AWS TODO
    bucket_name = cfg.AWS.bucket_name
    if bucket_name:
        final_samples_folder = None
        bucket_file_cache = []
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv', data_path)
    else:
        list_data_prep = read_csv(csv_file, data_path)

    # IF DEBUG IS ACTIVATED
    if debug:
        logging.warning(
            '\nDebug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.'
        )

    # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg and (2) check CRS match (tif and gpkg)
    valid_gpkg_set = set()
    for info in tqdm(list_data_prep, position=0):
        validate_raster(info['tif'], num_bands, meta_map)
        if info['gpkg'] not in valid_gpkg_set:
            gpkg_classes = validate_num_classes(
                info['gpkg'],
                num_classes,
                info['attribute_name'],
                dontcare,
                target_ids=targ_ids,
            )
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    if debug:
        # VALIDATION (debug only): Checking validity of features in vector files
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Checking validity of features in vector files"):
            # TODO: make unit to test this with invalid features.
            invalid_features = validate_features_from_gpkg(
                info['gpkg'], info['attribute_name'])
            if invalid_features:
                logging.critical(
                    f"{info['gpkg']}: Invalid geometry object(s) '{invalid_features}'"
                )

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        samples_size=samples_size,
        number_of_bands=num_bands,
        meta_map=meta_map,
        samples_folder=samples_dir,
        cfg=cfg)

    # Create the pixel_classes dict, keyed on the classes found during validation
    # (gpkg_classes holds the classes from the last GeoPackage validated above)
    pixel_classes = {key: 0 for key in gpkg_classes}
    background_val = 0
    pixel_classes[background_val] = 0
    class_prop = validate_class_prop_dict(pixel_classes, class_prop)
    pixel_classes[dontcare] = 0

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    logging.info(
        f"\nPreparing samples \n  Samples_size: {samples_size} \n  Overlap: {overlap} "
        f"\n  Validation set: {val_percent} % of created training samples")
    for info in tqdm(list_data_prep, position=0, leave=False):
        try:
            if bucket_name:
                bucket.download_file(info['tif'],
                                     "Images/" + info['tif'].split('/')[-1])
                info['tif'] = "Images/" + info['tif'].split('/')[-1]
                if info['gpkg'] not in bucket_file_cache:
                    bucket_file_cache.append(info['gpkg'])
                    bucket.download_file(info['gpkg'],
                                         info['gpkg'].split('/')[-1])
                info['gpkg'] = info['gpkg'].split('/')[-1]
                if info['meta']:
                    if info['meta'] not in bucket_file_cache:
                        bucket_file_cache.append(info['meta'])
                        bucket.download_file(info['meta'],
                                             info['meta'].split('/')[-1])
                    info['meta'] = info['meta'].split('/')[-1]

            logging.info(f"\nReading as array: {info['tif']}")
            with rasterio.open(info['tif'], 'r') as raster:
                # 1. Read the input raster image
                np_input_image, raster, dataset_nodata = image_reader_as_array(
                    input_image=raster,
                    clip_gpkg=info['gpkg'],
                    aux_vector_file=get_key_def('aux_vector_file',
                                                cfg['dataset'], None),
                    aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                  cfg['dataset'], None),
                    aux_vector_ids=get_key_def('aux_vector_ids',
                                               cfg['dataset'], None),
                    aux_vector_dist_maps=get_key_def('aux_vector_dist_maps',
                                                     cfg['dataset'], True),
                    aux_vector_dist_log=get_key_def('aux_vector_dist_log',
                                                    cfg['dataset'], True),
                    aux_vector_scale=get_key_def('aux_vector_scale',
                                                 cfg['dataset'], None))

                # 2. Burn vector file in a raster file
                logging.info(
                    f"\nRasterizing vector file (attribute: {info['attribute_name']}): {info['gpkg']}"
                )
                np_label_raster = vector_to_raster(
                    vector_file=info['gpkg'],
                    input_image=raster,
                    out_shape=np_input_image.shape[:2],
                    attribute_name=info['attribute_name'],
                    fill=background_val,
                    target_ids=targ_ids
                )  # background value in rasterized vector.

                if dataset_nodata is not None:
                    # 3. Set ignore_index value in label array where nodata in raster (only if nodata across all bands)
                    np_label_raster[dataset_nodata] = dontcare

            if debug:
                out_meta = raster.meta.copy()
                np_image_debug = np_input_image.transpose(2, 0, 1).astype(
                    out_meta['dtype'])
                out_meta.update({
                    "driver": "GTiff",
                    "height": np_image_debug.shape[1],
                    "width": np_image_debug.shape[2]
                })
                out_tif = samples_dir / f"{Path(info['tif']).stem}_clipped.tif"
                logging.debug(f"Writing clipped raster to {out_tif}")
                with rasterio.open(out_tif, "w", **out_meta) as dest:
                    dest.write(np_image_debug)

                out_meta = raster.meta.copy()
                np_label_debug = np.expand_dims(
                    np_label_raster,
                    axis=2).transpose(2, 0, 1).astype(out_meta['dtype'])
                out_meta.update({
                    "driver": "GTiff",
                    "height": np_label_debug.shape[1],
                    "width": np_label_debug.shape[2],
                    'count': 1
                })
                out_tif = samples_dir / f"{Path(info['gpkg']).stem}_clipped.tif"
                logging.debug(f"\nWriting final rasterized gpkg to {out_tif}")
                with rasterio.open(out_tif, "w", **out_meta) as dest:
                    dest.write(np_label_debug)

            # Mask the zeros from input image into label raster.
            if mask_reference:
                np_label_raster = mask_image(np_input_image, np_label_raster)

            if info['dataset'] == 'trn':
                out_file = trn_hdf5
            elif info['dataset'] == 'tst':
                out_file = tst_hdf5
            else:
                msg = f"\nDataset value must be trn or tst. Provided value is {info['dataset']}"
                logging.critical(msg)
                raise ValueError(msg)
            val_file = val_hdf5

            metadata = add_metadata_from_raster_to_sample(
                sat_img_arr=np_input_image,
                raster_handle=raster,
                meta_map=meta_map,
                raster_info=info)
            # Save label's per class pixel count to image metadata
            metadata['source_label_bincount'] = {
                class_num: count
                for class_num, count in enumerate(
                    np.bincount(np_label_raster.clip(min=0).flatten()))
                if count > 0
            }  # TODO: add this to add_metadata_from[...] function?

            np_label_raster = np.reshape(
                np_label_raster,
                (np_label_raster.shape[0], np_label_raster.shape[1], 1))
            # 3. Prepare samples!
            number_samples, number_classes = samples_preparation(
                in_img_array=np_input_image,
                label_array=np_label_raster,
                sample_size=samples_size,
                overlap=overlap,
                samples_count=number_samples,
                num_classes=number_classes,
                samples_file=out_file,
                val_percent=val_percent,
                val_sample_file=val_file,
                dataset=info['dataset'],
                pixel_classes=pixel_classes,
                dontcare=dontcare,
                image_metadata=metadata,
                min_annot_perc=min_annot_perc,
                class_prop=class_prop,
                stratd=stratd)

            out_file.flush()
        except OSError:
            logging.exception(
                f'\nAn error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                f'{Path(info["gpkg"]).stem} (gpkg).')
            continue

    trn_hdf5.close()
    val_hdf5.close()
    tst_hdf5.close()

    # Add up the number of pixels across all classes
    pixel_total = sum(pixel_classes.values())
    # Calculate the proportion of pixels of each class in the created samples
    pixel_classes_dict = {}
    for i in pixel_classes:
        pixel_classes_dict[i] = round((pixel_classes[i] / pixel_total) *
                                      100, 1) if pixel_total > 0 else 0
    # Log the proportion of pixels of each class in the created samples
    msg_pixel_classes = "\n".join("Pixels from class {}: {}%".format(k, v)
                                  for k, v in pixel_classes_dict.items())
    logging.info("\n" + msg_pixel_classes)

    logging.info(f"\nNumber of samples created: {number_samples}")

    if bucket_name and final_samples_folder:  # FIXME: final_samples_folder always None in current implementation
        logging.info('\nTransferring samples to the bucket')
        # samples_dir is a Path, so build the paths with the / operator
        bucket.upload_file(str(samples_dir / "trn_samples.hdf5"),
                           final_samples_folder + '/trn_samples.hdf5')
        bucket.upload_file(str(samples_dir / "val_samples.hdf5"),
                           final_samples_folder + '/val_samples.hdf5')
        bucket.upload_file(str(samples_dir / "tst_samples.hdf5"),
                           final_samples_folder + '/tst_samples.hdf5')
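Since main() takes a DictConfig, it is presumably invoked through Hydra. A minimal, hypothetical entry point, assuming Hydra >= 1.2 for the version_base argument; the config path and name below are placeholders, not taken from this example:

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="default", version_base=None)
def run(cfg: DictConfig) -> None:
    # Delegate to the dataset preparation routine above.
    main(cfg)

if __name__ == "__main__":
    run()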