예제 #1
0
def add_metadata_from_raster_to_sample(sat_img_arr: np.ndarray,
                                       raster_handle: dict, meta_map: dict,
                                       raster_info: dict) -> dict:
    """
    Build a metadata dictionary describing a sample's source raster.

    :param sat_img_arr: source image as array (opened with rasterio.read). Must be a 3-D
                        numpy array; its last axis is iterated as the bands.
    :param raster_handle: opened raster whose ``name`` and ``meta`` attributes are read
                          (presumably a rasterio dataset despite the ``dict`` annotation
                          — TODO confirm). ``meta`` must contain a "dtype" key.
    :param meta_map: meta map parameter from yaml (global section). When truthy, the yaml
                     file referenced by ``raster_info['meta']`` is required and merged in.
    :param raster_info: info from raster as read with read_csv (except at inference)
    :return: Returns a metadata dictionary populated with info from source raster, including original csv line and
             histogram: ``name``, ``csv_info``, every key of ``raster_handle.meta`` (``dtype``
             possibly overwritten, see below), ``source_raster_bincount`` mapping ``'band<i>'``
             to a list whose entry ``v`` is the number of pixels of value ``v`` in that band,
             plus the keys of the metadata yaml when ``meta_map`` is set.
    :raises AssertionError: if the array is not a 3-D ndarray or "dtype" is missing from the raster metadata.
    :raises NotImplementedError: if the dtype is not uint8/uint16 and the values do not fit either range.
    :raises ValueError: if ``meta_map`` is set but ``raster_info['meta']`` is not an existing file path.
    """
    # Validate the array before any method (min/max/bincount) is called on it.
    assert isinstance(sat_img_arr, np.ndarray) and sat_img_arr.ndim == 3, \
        f"Array should be 3-dimensional"

    metadata_dict = {
        'name': raster_handle.name,
        'csv_info': raster_info,
        'source_raster_bincount': {},
    }
    assert 'dtype' in raster_handle.meta, "\"dtype\" could not be found in source image metadata"
    metadata_dict.update(raster_handle.meta)

    if metadata_dict['dtype'] not in ("uint8", "uint16"):
        warnings.warn(
            f"Datatype should be \"uint8\" or \"uint16\". Got \"{metadata_dict['dtype']}\". "
        )
        # Relabel the dtype with the smallest unsigned type able to hold the actual values.
        arr_min, arr_max = sat_img_arr.min(), sat_img_arr.max()
        if arr_min >= 0 and arr_max <= 255:
            metadata_dict['dtype'] = "uint8"
        elif arr_min >= 0 and arr_max <= 65535:
            metadata_dict['dtype'] = "uint16"
        else:
            raise NotImplementedError(
                f"Min and max values of array ({[arr_min, arr_max]}) "
                f"are not contained in 8 bit nor 16 bit range. Datatype cannot be overwritten."
            )

    # Save bin count (i.e. histogram) to metadata. A list (not a set) keeps the link between
    # each pixel value (the index) and its count, and keeps equal counts for different values.
    # np.bincount requires non-negative integer values.
    for band_index in range(sat_img_arr.shape[2]):
        band = sat_img_arr[..., band_index]
        metadata_dict['source_raster_bincount'][f'band{band_index}'] = np.bincount(band.ravel()).tolist()

    if meta_map:
        meta_path = raster_info.get('meta') if raster_info else None
        # Check type before touching the filesystem: Path(None) would raise TypeError.
        if not (isinstance(meta_path, str) and Path(meta_path).is_file()):
            raise ValueError(
                "global configuration requested metadata mapping onto loaded "
                "samples, but raster did not have available metadata")
        metadata_dict.update(read_parameters(meta_path))
    return metadata_dict
예제 #2
0
                     100, 1) if pixel_total > 0 else 0
        print('Pixels from class', i, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        bucket.upload_file(samples_folder + "/trn_samples.hdf5",
                           final_samples_folder + '/trn_samples.hdf5')
        bucket.upload_file(samples_folder + "/val_samples.hdf5",
                           final_samples_folder + '/val_samples.hdf5')
        bucket.upload_file(samples_folder + "/tst_samples.hdf5",
                           final_samples_folder + '/tst_samples.hdf5')

    print("End of process")


if __name__ == '__main__':
    # CLI entry point: one positional argument, the yaml parameter file.
    parser = argparse.ArgumentParser(description='Sample preparation')
    parser.add_argument(
        'ParamFile', metavar='DIR', help='Path to training parameters stored in yaml'
    )
    args = parser.parse_args()
    params = read_parameters(args.ParamFile)

    # Wall-clock timing of the whole preparation run.
    start_time = time.time()
    tqdm.write(f'\n\nStarting images to samples preparation with {args.ParamFile}\n\n')
    main(params)
    print(f"Elapsed time:{time.time() - start_time}")
예제 #3
0
                    inputs = data['sat_img'].to(device)
                    labels = data['map_img'].to(device)

                    outputs = model(inputs)
                    if isinstance(outputs, OrderedDict):
                        outputs = outputs['out']

                    vis_from_batch(params,
                                   inputs,
                                   outputs,
                                   batch_index=batch_index,
                                   vis_path=vis_path,
                                   labels=labels,
                                   dataset=dataset,
                                   ep_num=ep_num)
    tqdm.write(f'Saved visualization figures.\n')


if __name__ == '__main__':
    # CLI entry point: the single positional argument is the yaml parameter file, passed
    # to main both as parsed parameters and as a Path.
    print('Start\n')
    parser = argparse.ArgumentParser(description='Training execution')
    parser.add_argument(
        'param_file', metavar='DIR', help='Path to training parameters stored in yaml'
    )
    args = parser.parse_args()
    params = read_parameters(args.param_file)
    config_path = Path(args.param_file)

    main(params, config_path)
    print('End of training')
예제 #4
0
def main(params):
    """
    Run inference on every image listed in the configuration.

    For a classification task the image list is handed to ``classifier``. For a
    segmentation task the checkpoint is loaded, then each image is read, its band count
    reconciled with ``params['global']['number_of_bands']``, inferred with
    ``sem_seg_inference`` and saved with ``vis`` (and uploaded when a bucket is used).

    :param params: (dict) Parameters found in the yaml config file.
    :raises ValueError: if ``params['global']['task']`` is neither "classification" nor "segmentation".
    """
    # SET BASIC VARIABLES AND PATHS
    since = time.time()

    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(f'Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.')

    # Validate the task first: the class count and the model configured below depend on it,
    # so an unknown value must fail here with a clear error.
    task = params['global']['task']
    if task not in ('classification', 'segmentation'):
        raise ValueError(
            f"The task should be either classification or segmentation. The provided value is {task}")

    num_classes = params['global']['num_classes']
    if task == 'segmentation':
        # assume background is implicitly needed (makes no sense to predict with one class, for example.)
        # this will trigger some warnings elsewhere, but should succeed nonetheless
        num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.
    else:
        num_classes_corrected = num_classes

    chunk_size = get_key_def('chunk_size', params['inference'], 512)
    overlap = get_key_def('overlap', params['inference'], 10)
    # overlap is a percentage of chunk_size, converted here to a pixel count
    nbr_pix_overlap = int(math.floor(overlap / 100 * chunk_size))
    num_bands = params['global']['number_of_bands']

    img_dir_or_csv = params['inference']['img_dir_or_csv_file']

    default_working_folder = Path(params['inference']['state_dict_path']).parent.joinpath(f'inference_{num_bands}bands')
    working_folder = Path(get_key_def('working_folder', params['inference'], default_working_folder))  # TODO: remove working_folder parameter in all templates
    working_folder.mkdir(exist_ok=True)
    print(f'Inferences will be saved to: {working_folder}\n\n')

    bucket = None
    bucket_file_cache = []  # bucket keys of metadata files already downloaded
    bucket_name = params['global']['bucket_name']

    # CONFIGURE MODEL
    model, state_dict_path, _ = net(params, num_channels=num_classes_corrected, inference=True)

    num_devices = params['global']['num_gpus'] if params['global']['num_gpus'] else 0
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')

    if lst_device_ids:
        print(f"Number of cuda devices requested: {num_devices}. Cuda devices available: {lst_device_ids}. Using {lst_device_ids[0]}\n\n")
    else:
        warnings.warn(f"No Cuda device available. This process will only run on CPU")

    try:
        model.to(device)
    except RuntimeError:
        print(f"Unable to use device. Trying device 0")
        device = torch.device(f'cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        if img_dir_or_csv.endswith('.csv'):
            bucket.download_file(img_dir_or_csv, 'img_csv_file.csv')
            list_img = read_csv('img_csv_file.csv', inference=True)
        else:
            raise NotImplementedError(
                'Specify a csv file containing images for inference. Directory input not implemented yet')
    elif img_dir_or_csv.endswith('.csv'):
        list_img = read_csv(img_dir_or_csv, inference=True)
    else:
        img_dir = Path(img_dir_or_csv)
        assert img_dir.is_dir(), f'Could not find directory "{img_dir_or_csv}"'
        # Directory inputs only know the image path: these dicts have no 'meta' key.
        list_img = [{'tif': img_path} for img_path in sorted(img_dir.glob('*.tif'))]  # FIXME: what if .tif is in caps (.TIF) ?
        assert list_img, f'No .tif files found in {img_dir_or_csv}'

    if task == 'classification':
        classifier(params, list_img, model, device, working_folder)  # FIXME: why don't we load from checkpoint in classification?
    else:
        if bucket:
            bucket.download_file(state_dict_path, "saved_model.pth.tar")
            model, _ = load_from_checkpoint("saved_model.pth.tar", model, inference=True)
        else:
            model, _ = load_from_checkpoint(state_dict_path, model, inference=True)

        # Settings that do not vary from one image to the next, read once.
        reader_kwargs = dict(
            scale=get_key_def('scale_data', params['global'], None),
            aux_vector_file=get_key_def('aux_vector_file', params['global'], None),
            aux_vector_attrib=get_key_def('aux_vector_attrib', params['global'], None),
            aux_vector_ids=get_key_def('aux_vector_ids', params['global'], None),
            aux_vector_dist_maps=get_key_def('aux_vector_dist_maps', params['global'], True),
            aux_vector_scale=get_key_def('aux_vector_scale', params['global'], None),
        )
        meta_map = get_key_def("meta_map", params["global"], {})
        meta_layer_count = MetaSegmentationDataset.get_meta_layer_count(meta_map)
        save_heatmaps = get_key_def('heatmaps', params['inference'], False)

        with tqdm(list_img, desc='image list', position=0) as _tqdm:
            for img in _tqdm:
                img_name = Path(img['tif']).name
                img_stem = img_name.split('.')[0]
                if bucket:
                    # A Path (not a str) so the is_file() check below also works for downloads.
                    local_img = Path("Images", img_name)
                    bucket.download_file(img['tif'], str(local_img))
                    inference_image = f"Classified_Images/{img_stem}_inference.tif"
                    if img.get('meta'):
                        if img['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(img['meta'])
                            bucket.download_file(img['meta'], img['meta'].split('/')[-1])
                        img['meta'] = img['meta'].split('/')[-1]
                else:
                    local_img = Path(img['tif'])

                assert local_img.is_file(), f"Could not open raster file at {local_img}"

                with rasterio.open(local_img, 'r') as raster:
                    np_input_image = image_reader_as_array(input_image=raster, **reader_kwargs)

                metadata = None
                if meta_map:
                    img_meta = img.get('meta')
                    assert isinstance(img_meta, str) and os.path.isfile(img_meta), \
                        "global configuration requested metadata mapping onto loaded samples, but raster did not have available metadata"
                    metadata = read_parameters(img_meta)

                if debug:
                    _tqdm.set_postfix(OrderedDict(img_name=img_name,
                                                  img=np_input_image.shape,
                                                  img_min_val=np.min(np_input_image),
                                                  img_max_val=np.max(np_input_image)))

                # number_of_bands counts the image bands plus the meta layers.
                input_band_count = np_input_image.shape[2] + meta_layer_count
                if input_band_count > num_bands:
                    # FIXME: Following statements should be reconsidered to better manage inconsistencies between
                    #  provided number of band and image number of band.
                    # Keep only as many image bands as leave room for the meta layers.
                    image_bands_to_keep = num_bands - meta_layer_count
                    warnings.warn(f"Input image has more band than the number provided in the yaml file ({num_bands}). "
                                  f"Will use the first {image_bands_to_keep} bands of the input image.")
                    np_input_image = np_input_image[:, :, 0:image_bands_to_keep]
                    print(f"Input image's new shape: {np_input_image.shape}")

                elif input_band_count < num_bands:
                    warnings.warn(f"Skipping image: The number of bands requested in the yaml file ({num_bands}) "
                                  f"can not be larger than the number of band in the input image ({input_band_count}).")
                    continue

                # START INFERENCES ON SUB-IMAGES
                sem_seg_results_per_class = sem_seg_inference(model, np_input_image, nbr_pix_overlap, chunk_size, num_classes_corrected,
                                                              device, meta_map, metadata, output_path=working_folder, index=_tqdm.n, debug=debug)

                # CREATE GEOTIF FROM METADATA OF ORIGINAL IMAGE
                tqdm.write(f'Saving inference...\n')
                if save_heatmaps:
                    tqdm.write(f'Heatmaps will be saved.\n')
                vis(params, np_input_image, sem_seg_results_per_class, working_folder, inference_input_path=local_img, debug=debug)

                tqdm.write(f"\n\nSemantic segmentation of image {img_name} completed\n\n")
                if bucket:
                    bucket.upload_file(inference_image, os.path.join(working_folder, f"{img_stem}_inference.tif"))

    time_elapsed = time.time() - since
    print('Inference completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
예제 #5
0
    parser.add_argument('-p',
                        '--param',
                        metavar='yaml_file',
                        nargs=1,
                        help='Path to parameters stored in yaml')
    parser.add_argument('-i',
                        '--input',
                        metavar='model_pth img_dir',
                        nargs=2,
                        help='model_path and image_dir')
    args = parser.parse_args()

    # if a yaml is inputted, get those parameters and get model state_dict to overwrite global parameters afterwards
    if args.param:
        input_params = read_parameters(args.param[0])
        model_ckpt = get_key_def('state_dict_path',
                                 input_params['inference'],
                                 expected_type=str)
        # load checkpoint
        checkpoint = load_checkpoint(model_ckpt)
        if 'params' in checkpoint.keys():
            params = checkpoint['params']
            # overwrite with inputted parameters
            compare_config_yamls(yaml1=params,
                                 yaml2=input_params,
                                 update_yaml1=True)
        else:
            warnings.warn(
                'No parameters found in checkpoint. Defaulting to parameters from inputted yaml.'
                'Use GDL version 1.3 or more.')
def main(params):
    """
    Training and validation datasets preparation.

    For each row of the preparation csv, the vector file (gpkg) is burned into a label
    raster on the grid of the input image. Both are then passed to ``samples_preparation``
    together with the trn/val/tst hdf5 files created in a new samples folder. With a
    bucket, input files are downloaded per row and the hdf5 files are uploaded at the end.

    :param params: (dict) Parameters found in the yaml config file.
    :raises AssertionError: if the task is not segmentation, a local tif/gpkg is missing,
                            or (debug mode) invalid geometries are found.
    :raises ValueError: if the preparation csv has no rows.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []  # bucket keys (gpkg / meta) already downloaded

    assert params['global'][
        'task'] == 'segmentation', f"images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = params['global']['bucket_name']
    data_path = Path(params['global']['data_path'])
    data_path.mkdir(exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = params['sample']['sampling']['map']
    num_bands = params['global']['number_of_bands']
    num_classes = params['global']['num_classes']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(f'Debug mode activate. Execution may take longer...')

    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'  # TODO: validate this is preferred name structure
    bucket = None
    final_samples_folder = None
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        final_samples_folder = os.path.join(data_path, "samples") if data_path else "samples"
        # A Path (relative to the working directory) so is_dir()/mkdir below and the
        # final upload work in bucket mode as well.
        samples_folder = Path(samples_folder_name)
    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(samples_folder_name)

    if not list_data_prep:
        raise ValueError(f'No rows found in csv file: {csv_file}')

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    samples_folder.mkdir(exist_ok=False)  # FIXME: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(
        f'\nSuccessfully read csv file: {Path(csv_file).stem}\nNumber of rows: {len(list_data_prep)}\nCopying first entry:\n{list_data_prep[0]}\n'
    )
    ignore_index = get_key_def('ignore_index', params['training'], -1)

    # With a bucket, csv paths are bucket keys only downloaded later (per row), so they
    # cannot be checked on the local filesystem here.
    if not bucket_name:
        for info in tqdm(list_data_prep,
                         position=0,
                         desc=f'Asserting existence of tif and gpkg files in csv'):
            assert Path(info['tif']).is_file(
            ), f'Could not locate "{info["tif"]}". Make sure file exists in this directory.'
            assert Path(info['gpkg']).is_file(
            ), f'Could not locate "{info["gpkg"]}". Make sure file exists in this directory.'

    if debug:
        # NOTE(review): these checks open the vector files from their csv paths, so they
        # presumably only work with local files (no bucket) — confirm.
        with tqdm(list_data_prep,
                  position=0,
                  desc=f"Validating presence of {num_classes} classes in vector files") as _tqdm:
            for info in _tqdm:
                # Per-row details, instead of a description frozen on a stale row.
                _tqdm.set_postfix(OrderedDict(attribute=info['attribute_name'],
                                              gpkg=Path(info['gpkg']).stem))
                validate_num_classes(info['gpkg'], num_classes,
                                     info['attribute_name'], ignore_index)
        with tqdm(list_data_prep,
                  position=0,
                  desc=f"Checking validity of features in vector files"
                  ) as _tqdm:
            invalid_features = {}  # gpkg stem -> ids of features with invalid geometry
            for info in _tqdm:
                # Extract vector features to burn in the raster image
                with fiona.open(info['gpkg'], 'r') as src:  # TODO: refactor as independent function
                    lst_vector = list(src)
                shapes = lst_ids(list_vector=lst_vector,
                                 attr_name=info['attribute_name'])
                gpkg_stem = str(Path(info['gpkg']).stem)
                flat_shapes = [v for vecs in shapes.values() for v in vecs]
                for index, (geom, _) in enumerate(
                        tqdm(flat_shapes, leave=False, position=1)):
                    # geom must be a valid GeoJSON geometry type and non-empty
                    geom = getattr(geom, '__geo_interface__', None) or geom
                    if not is_valid_geom(geom):
                        # NOTE(review): assumes flat_shapes follows lst_vector order — verify against lst_ids.
                        feature_id = lst_vector[index]["id"]
                        gpkg_invalid_ids = invalid_features.setdefault(gpkg_stem, [])
                        if feature_id not in gpkg_invalid_ids:  # ignore feature if already appended
                            gpkg_invalid_ids.append(feature_id)
            assert not invalid_features, \
                f'Invalid geometry object(s) for "gpkg:ids": \"{invalid_features}\"'

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    # 'sampling' ordereddict validation
    check_sampling_dict()

    # Pixel count per class value; presumably updated in place by samples_preparation.
    pixel_classes = {class_value: 0 for class_value in range(num_classes + 1)}
    pixel_classes[ignore_index] = 0  # FIXME: pixel_classes dict needs to be populated with classes obtained from target

    # Settings that do not vary from one row to the next, read once.
    # NOTE(review): the label fill reads 'ignore_idx' while ignore_index above reads
    # 'ignore_index' — confirm which key is intended.
    label_fill = get_key_def('ignore_idx', get_key_def('training', params, {}), 0)
    reader_kwargs = dict(
        scale=get_key_def('scale_data', params['global'], None),
        aux_vector_file=get_key_def('aux_vector_file', params['global'], None),
        aux_vector_attrib=get_key_def('aux_vector_attrib', params['global'], None),
        aux_vector_ids=get_key_def('aux_vector_ids', params['global'], None),
        aux_vector_dist_maps=get_key_def('aux_vector_dist_maps', params['global'], True),
        aux_vector_dist_log=get_key_def('aux_vector_dist_log', params['global'], True),
        aux_vector_scale=get_key_def('aux_vector_scale', params['global'], None),
    )
    meta_map = get_key_def("meta_map", params["global"], {})
    meta_layer_count = MetaSegmentationDataset.get_meta_layer_count(meta_map)

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    try:
        with tqdm(list_data_prep,
                  position=0,
                  leave=False,
                  desc=f'Preparing samples') as _tqdm:
            for info in _tqdm:
                _tqdm.set_postfix(
                    OrderedDict(tif=f'{Path(info["tif"]).stem}',
                                sample_size=samples_size))
                try:
                    if bucket_name:
                        local_tif = "Images/" + info['tif'].split('/')[-1]
                        bucket.download_file(info['tif'], local_tif)
                        info['tif'] = local_tif
                        if info['gpkg'] not in bucket_file_cache:
                            bucket_file_cache.append(info['gpkg'])
                            bucket.download_file(info['gpkg'],
                                                 info['gpkg'].split('/')[-1])
                        info['gpkg'] = info['gpkg'].split('/')[-1]
                        if info['meta']:
                            if info['meta'] not in bucket_file_cache:
                                bucket_file_cache.append(info['meta'])
                                bucket.download_file(info['meta'],
                                                     info['meta'].split('/')[-1])
                            info['meta'] = info['meta'].split('/')[-1]

                    with rasterio.open(info['tif'], 'r') as raster:
                        # Burn vector file in a raster file
                        np_label_raster = vector_to_raster(
                            vector_file=info['gpkg'],
                            input_image=raster,
                            attribute_name=info['attribute_name'],
                            fill=label_fill)
                        # Read the input raster image
                        np_input_image = image_reader_as_array(
                            input_image=raster, **reader_kwargs)

                    # Mask the zeros from input image into label raster.
                    if params['sample']['mask_reference']:
                        np_label_raster = mask_image(np_input_image,
                                                     np_label_raster)

                    if info['dataset'] == 'trn':
                        out_file = trn_hdf5
                    elif info['dataset'] == 'tst':
                        out_file = tst_hdf5
                    else:
                        raise ValueError(
                            f"Dataset value must be trn or tst. Provided value is {info['dataset']}"
                        )
                    # Bound for every row: a 'tst' row used to leave it unassigned, which
                    # raised UnboundLocalError when the csv started with a tst row.
                    val_file = val_hdf5

                    metadata = None
                    if isinstance(info['meta'], str) and Path(info['meta']).is_file():
                        metadata = read_parameters(info['meta'])

                    # FIXME: think this through. User will have to calculate the total number of bands including meta layers and
                    #  specify it in yaml. Is this the best approach? What if metalayers are added on the fly ?
                    input_band_count = np_input_image.shape[2] + meta_layer_count
                    # FIXME: could this assert be done before getting into this big for loop?
                    assert input_band_count == num_bands, \
                        f"The number of bands in the input image ({input_band_count}) and the parameter" \
                        f"'number_of_bands' in the yaml file ({num_bands}) should be identical"

                    # Add a trailing singleton axis to the label raster.
                    np_label_raster = np.reshape(
                        np_label_raster,
                        (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                    number_samples, number_classes = samples_preparation(
                        np_input_image, np_label_raster, samples_size, overlap,
                        number_samples, number_classes, out_file, val_percent,
                        val_file, info['dataset'], pixel_classes, metadata)

                    _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                    out_file.flush()
                except Exception as e:
                    # Best effort: one bad row must not abort the whole preparation.
                    warnings.warn(
                        f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                        f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
    finally:
        trn_hdf5.close()
        val_hdf5.close()
        tst_hdf5.close()

    # prints the proportion of pixels of each class for the samples created
    pixel_total = sum(pixel_classes.values())
    for class_value, pixel_count in pixel_classes.items():
        prop = round((pixel_count / pixel_total) * 100, 1) if pixel_total > 0 else 0
        print('Pixels from class', class_value, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        for dataset in ('trn', 'val', 'tst'):
            hdf5_name = f'{dataset}_samples.hdf5'
            bucket.upload_file(str(samples_folder / hdf5_name),
                               final_samples_folder + '/' + hdf5_name)

    print("End of process")