Example #1
0
def test_predict_neighbors(data, metadata, mod):
    """Check that predict_neighbors returns one feature row and one distance per neighbor.

    The first row of ``data`` is used as the target and the whole of ``data``
    is used as the neighbor pool, so the pool still contains the target row.

    Args:
        data: table of candidate boxes; only ``.iloc`` is used here
            (presumably a pytest fixture -- confirm against conftest).
        metadata: forwarded unchanged to ``neighbors.predict_neighbors``.
        mod: object exposing ``ensemble_model``, whose ``output.shape[1]``
            gives the expected feature width.
    """
    k_neighbors = 2
    target = data.iloc[0]
    neighbor_pool = data

    # Open the raster in a context manager so the file handle is released
    # even when the prediction raises.
    with rasterio.open(test_sensor_tile) as raster:
        feature_array, distances = neighbors.predict_neighbors(
            target,
            metadata=metadata,
            HSI_size=20,
            raster=raster,
            neighbor_pool=neighbor_pool,
            model=mod.ensemble_model,
            k_neighbors=k_neighbors)

    # One row of features per requested neighbor, as wide as the model output.
    assert feature_array.shape[0] == k_neighbors
    assert feature_array.shape[1] == mod.ensemble_model.output.shape[1]

    assert len(distances) == k_neighbors
Example #2
0
def test_predict_neighbors(data, metadata, mod):
    """Check predict_neighbors against a pool that excludes the target itself.

    The first row of ``data`` is the target; every other row forms the
    neighbor pool. Five neighbors are requested and the feature width is
    compared with the ``submodel_concat`` layer of the ensemble model.

    Args:
        data: table of candidate boxes, indexed per box
            (presumably a pytest fixture -- confirm against conftest).
        metadata: forwarded unchanged to ``neighbors.predict_neighbors``.
        mod: object exposing ``ensemble_model`` with a layer named
            ``submodel_concat``.
    """
    k_neighbors = 5
    target = data.iloc[0]

    # ``target`` is a row Series: its ``.index`` holds the column labels,
    # while its ``.name`` is the row label. Compare against the row label so
    # the target row (and only that row) is dropped from the pool.
    neighbor_pool = data[data.index != target.name]

    # Open the raster in a context manager so the file handle is released
    # even when the prediction raises.
    with rasterio.open(test_sensor_tile) as raster:
        feature_array, distances = neighbors.predict_neighbors(
            target,
            metadata=metadata,
            HSI_size=20,
            raster=raster,
            neighbor_pool=neighbor_pool,
            model=mod.ensemble_model,
            k_neighbors=k_neighbors)

    expected_width = mod.ensemble_model.get_layer(
        "submodel_concat").output.shape[1]
    assert feature_array.shape[0] == k_neighbors
    assert feature_array.shape[1] == expected_width

    assert len(distances) == k_neighbors
Example #3
0
def generate_tfrecords(HSI_sensor_path,
                       RGB_sensor_path,
                       domain,
                       site,
                       number_of_sites,
                       number_of_domains,
                       elevation,
                       species_label_dict,
                       chunk_size=1000,
                       savedir=".",
                       HSI_size=20,
                       RGB_size=100,
                       classes=20,
                       train=True,
                       extend_HSI_box=0,
                       extend_RGB_box=0,
                       shuffle=True,
                       shapefile=None,
                       csv_file=None,
                       label_column="label",
                       ensemble_model=None,
                       k_neighbors=5,
                       raw_boxes=None):
    """Crop HSI and RGB windows for every box and write them to chunked tfrecords.

    Args:
        HSI_sensor_path: path of the HSI raster, opened with rasterio
        RGB_sensor_path: path of the RGB raster, opened with rasterio
        domain: metadata site domain as integer
        site: metadata site label as integer
        number_of_sites: total number of sites used for one-hot encoding
        number_of_domains: total number of domains used for one-hot encoding
        elevation: height above sea level in meters
        species_label_dict: taxonID -> numeric label. When None, a dict is built
            from the labels that were read and saved to
            {savedir}/species_class_labels.csv
        chunk_size: number of windows per tfrecord
        savedir: directory to save tfrecords
        HSI_size: size in pixels of one side of image
        RGB_size: size in pixels of one side of image
        classes: forwarded to write_tfrecord
        train: training mode to include labels in the records
        extend_HSI_box: units in meters to expand DeepForest bounding box to give crop more context
        extend_RGB_box: units in meters to expand DeepForest bounding box to give crop more context
        shuffle: shuffle records before chunking (only applied when train is True)
        shapefile: path to a shapefile of boxes; used in preference to csv_file
        csv_file: path to a csv with a WKT "geometry" column, used when shapefile is None
        label_column: name of the column holding the species label
        ensemble_model: an ensemble model that predicts neighbor features.
            When None no neighbors are extracted and None placeholders are stored
        k_neighbors: number of neighbors to extract
        raw_boxes: path, readable by geopandas, of the boxes to choose from for neighbor analysis
    Returns:
        filenames: list of tfrecord paths, one per chunk (empty when no crop succeeded)
    Raises:
        AttributeError: if neither shapefile nor csv_file is given
    """

    if csv_file is None and shapefile is None:
        raise AttributeError("Either pass a shapefile=, or csv_file argument")

    labels = []
    HSI_crops = []
    RGB_crops = []
    indices = []
    neighbor_arrays = []
    neighbor_distances = []

    #Context managers close both rasters even if reading or cropping raises
    with rasterio.open(HSI_sensor_path) as HSI_src, \
            rasterio.open(RGB_sensor_path) as RGB_src:

        #Read csv file
        if shapefile is None:
            basename = os.path.splitext(os.path.basename(csv_file))[0]
            gdf = pd.read_csv(csv_file)
            gdf['geometry'] = gdf['geometry'].apply(wkt.loads)
            gdf = gpd.GeoDataFrame(gdf)

            #assign crs
            gdf.crs = RGB_src.crs

        else:
            basename = os.path.splitext(os.path.basename(shapefile))[0]
            gdf = gpd.read_file(shapefile)

        #Remove any nan and species not in the label dict if provided
        gdf = gdf[~gdf[label_column].isnull()]
        if species_label_dict is not None:
            gdf = gdf[gdf[label_column].isin(list(species_label_dict.keys()))]

        gdf["box_id"] = gdf.index.values

        #Give an individual column
        gdf["individual"] = gdf.index.values

        #Neighbor inputs are identical for every row, build them once
        if ensemble_model is not None:
            neighbor_pool = gpd.read_file(raw_boxes)
            one_hot_sites = tf.one_hot(site, number_of_sites)
            one_hot_domains = tf.one_hot(domain, number_of_domains)
            metadata = [elevation, one_hot_sites, one_hot_domains]

        for index, row in gdf.iterrows():
            try:
                HSI_crop = crop_image(HSI_src, row["geometry"], extend_HSI_box)
                RGB_crop = crop_image(RGB_src, row["geometry"], extend_RGB_box)
            except Exception as e:
                #Best effort: report the row and skip it
                print("row {} failed with {}".format(index, e))
                continue

            #Add training label only once the crops succeeded, otherwise a
            #skipped row would shift every later label against its crop
            if train:
                labels.append(row[label_column])

            HSI_crops.append(HSI_crop)
            RGB_crops.append(RGB_crop)
            indices.append(int(row["point_id"]))

            #extract neighbors
            if ensemble_model is not None:
                #Reuse the already open HSI raster instead of opening (and
                #leaking) a new handle for every row
                neighbor_array, neighbor_distance = neighbors.predict_neighbors(
                    row,
                    metadata=metadata,
                    HSI_size=HSI_size,
                    raster=HSI_src,
                    neighbor_pool=neighbor_pool,
                    model=ensemble_model,
                    k_neighbors=k_neighbors)
                neighbor_arrays.append(neighbor_array.astype("float32"))
                neighbor_distances.append(neighbor_distance)

            else:
                neighbor_arrays.append(None)
                neighbor_distances.append(None)

    #If no species label dict was passed
    if species_label_dict is None:
        #Create and save a new species label dict
        unique_species_labels = np.unique(labels)
        species_label_dict = {
            label: position
            for position, label in enumerate(unique_species_labels)
        }
        pd.DataFrame(list(species_label_dict.items()),
                     columns=["taxonID", "label"]).to_csv(
                         "{}/species_class_labels.csv".format(savedir))

    numeric_species_labels = [species_label_dict[x] for x in labels]

    #shuffle before writing to help with validation data split
    if shuffle:
        print("Shuffling")
        #zip(*z) cannot be unpacked when there are no records, so skip it
        if train and HSI_crops:
            z = list(
                zip(HSI_crops, RGB_crops, indices, numeric_species_labels,
                    neighbor_arrays, neighbor_distances))
            random.shuffle(z)
            HSI_crops, RGB_crops, indices, numeric_species_labels, neighbor_arrays, neighbor_distances = zip(
                *z)

    #divide into chunks for a single tfrecord. Stop at len(HSI_crops) so an
    #exact multiple of chunk_size does not produce a trailing empty tfrecord
    filenames = []
    chunk_starts = range(0, len(HSI_crops), chunk_size)
    for counter, i in enumerate(chunk_starts):
        chunk_HSI_crops = HSI_crops[i:i + chunk_size]
        chunk_RGB_crops = RGB_crops[i:i + chunk_size]
        chunk_index = indices[i:i + chunk_size]

        #if neighbors
        chunk_neighbor_arrays = neighbor_arrays[i:i + chunk_size]
        chunk_neighbor_distances = neighbor_distances[i:i + chunk_size]

        #All records in a single shapefile are the same site
        chunk_sites = np.repeat(site, len(chunk_index))
        chunk_domains = np.repeat(domain, len(chunk_index))
        chunk_elevations = np.repeat(elevation, len(chunk_index))

        if train:
            chunk_labels = numeric_species_labels[i:i + chunk_size]
        else:
            chunk_labels = None

        #resize crops and ensure dtypes, only the HSI crops are normalized
        resized_HSI_crops = [
            image_normalize(resize(x, HSI_size, HSI_size).astype(np.float32))
            for x in chunk_HSI_crops
        ]
        resized_RGB_crops = [
            resize(x, RGB_size, RGB_size).astype(np.float32)
            for x in chunk_RGB_crops
        ]

        filename = "{}/{}_{}.tfrecord".format(savedir, basename, counter)

        write_tfrecord(filename=filename,
                       HSI_images=resized_HSI_crops,
                       RGB_images=resized_RGB_crops,
                       labels=chunk_labels,
                       domains=chunk_domains,
                       sites=chunk_sites,
                       elevations=chunk_elevations,
                       indices=chunk_index,
                       neighbor_arrays=chunk_neighbor_arrays,
                       neighbor_distances=chunk_neighbor_distances,
                       number_of_sites=number_of_sites,
                       number_of_domains=number_of_domains,
                       classes=classes)

        filenames.append(filename)

    return filenames