Example #1
0
def main(dirname, generate=True, cpus=2):
    """Run DeepForest predicted shapefiles in parallel on a dask cluster.

    Args:
        dirname: directory of DeepForest predicted shapefiles to run
        generate: whether tfrecords need to be generated/overwritten
            (not consumed in this wrapper; kept for interface compatibility)
        cpus: number of dask workers to start
    """
    shapefile_list = find_shapefiles(dirname=dirname)
    dask_client = start_cluster.start(cpus=cpus)

    # One task per shapefile; block until every task has finished.
    submitted = dask_client.map(run, shapefile_list)
    wait(submitted)

    # Print each task's return value; .result() also re-raises task exceptions.
    for task in submitted:
        print(task.result())
Example #2
0
def predict(dirname, savedir, generate=True, cpus=2, parallel=True, height=40, width=40, channels=3):
    """Create a wrapper dask cluster and run list of shapefiles in parallel (optional)
        Args:
            dirname: directory of DeepForest predicted shapefiles to run
            savedir: directory to write processed shapefiles
            generate: Do tfrecords need to be generated/overwritten or use existing records?
            cpus: Number of dask cpus to run
            parallel: if True, fan shapefiles out over a dask cluster; otherwise run serially
            height: crop height passed through to _predict_
            width: crop width passed through to _predict_
            channels: number of image channels passed through to _predict_
    """
    shapefiles = find_shapefiles(dirname=dirname)
    
    if parallel:
        client = start_cluster.start(cpus=cpus)
        futures = client.map(_predict_,shapefiles, create_records=generate, savedir=savedir, height=height, width=width, channels=channels)
        wait(futures)
        
        # Print each result; .result() also re-raises any task exception.
        for future in futures:
            print(future.result())
    else:
        for shapefile in shapefiles:
            # Bug fix: this branch previously passed an undefined name `model_path`
            # (NameError at runtime) and dropped height/width/channels. Call
            # _predict_ with the same keyword arguments the parallel branch uses.
            _predict_(shapefile, create_records=generate, savedir=savedir, height=height, width=width, channels=channels)
Example #3
0
from DeepTreeAttention.trees import AttentionModel, __file__
from DeepTreeAttention.generators import boxes
from DeepTreeAttention.utils.start_cluster import start
from DeepTreeAttention.utils.paths import *

from distributed import wait

# Clean up outputs from any previous run: weak-label shapefiles and
# pretraining tfrecords are regenerated from scratch below.
old_files = glob.glob("/orange/idtrees-collab/DeepTreeAttention/WeakLabels/*")
[os.remove(x) for x in old_files]  # NOTE(review): comprehension used purely for side effects; a plain for-loop would be clearer
old_files = glob.glob("/orange/idtrees-collab/DeepTreeAttention/tfrecords/pretraining/*")
[os.remove(x) for x in old_files]


# Start a dask cluster used below to validate csv files in parallel.
client = start(cpus=60, mem_size="10GB") 

# Candidate weak-label csvs produced by a prior species-classification run.
weak_records = glob.glob(os.path.join("/orange/idtrees-collab/species_classification/confident_predictions","*.csv"))
#Check if complete
def check_shape(x):
    df = pd.read_csv(x)
    if len(df.columns) == 12:
        return x
    else:
        return None
    
# Keep only csvs with the full 12-column schema; check_shape returns None otherwise.
# NOTE(review): x.result() is invoked twice per future; gathering once into a
# local list would avoid the second round-trip to the scheduler.
futures = client.map(check_shape,weak_records)
completed_records = [x.result() for x in futures if x.result() is not None]

#Create a dask dataframe of csv files
def main(
    field_data,
    RGB_size,
    HSI_size,
    rgb_dir, 
    hyperspectral_dir,
    savedir=".", 
    chunk_size=400,
    extend_box=0, 
    hyperspectral_savedir=".", 
    n_workers=20,
    saved_model=None, 
    use_dask=False, 
    shuffle=True,
    species_classes_file=None,
    site_classes_file=None):
    """Prepare NEON field data into tfrecords
    Args:
        field_data: shp file with location and class of each field collected point
        RGB_size: size of the resized RGB training crop
        HSI_size: size of the resized hyperspectral training crop
        rgb_dir: recursive glob pattern locating candidate RGB tiles
        hyperspectral_dir: recursive glob pattern locating candidate hyperspectral tiles
        savedir: directory to save completed tfrecords
        chunk_size: number of crops per tfrecord shard
        extend_box: units in meters to add to the edge of a predicted box to give more context
        hyperspectral_savedir: location to save converted .h5 to .tif
        n_workers: number of dask workers
        saved_model: optional path to a saved model passed through to run()
        use_dask: process one plot per dask task instead of serially
        shuffle: shuffle lists before writing
        species_classes_file: optional path to a two column csv file with index and species labels
        site_classes_file: optional path to a two column csv file with index and site labels
    Returns:
        tfrecords: list of created tfrecords
    """ 
    df = gpd.read_file(field_data)
    plot_names = df.plotID.unique()

    # Candidate sensor tiles; run() picks the tile that covers each plot.
    hyperspectral_pool = glob.glob(hyperspectral_dir, recursive=True)
    rgb_pool = glob.glob(rgb_dir, recursive=True)

    labels = []
    HSI_crops = []
    RGB_crops = []
    sites = []
    box_indexes = []
    elevations = []
    if use_dask:
        client = start_cluster.start(cpus=n_workers, mem_size="10GB")
        futures = []
        for plot in plot_names:
            future = client.submit(
                run,
                plot=plot,
                df=df,
                rgb_pool=rgb_pool,
                hyperspectral_pool=hyperspectral_pool,
                extend_box=extend_box,
                hyperspectral_savedir=hyperspectral_savedir,
                saved_model=saved_model
            )
            futures.append(future)

        wait(futures)
        for x in futures:
            try:
                plot_HSI_crops, plot_RGB_crops, plot_labels, plot_sites, plot_elevations, plot_box_index = x.result()

                #Append to general plot list
                HSI_crops.extend(plot_HSI_crops)
                RGB_crops.extend(plot_RGB_crops)
                labels.extend(plot_labels)
                sites.extend(plot_sites)
                elevations.extend(plot_elevations)
                box_indexes.extend(plot_box_index)
            except Exception as e:
                # Best effort: a failed plot is reported but does not abort the run.
                print("Future failed with {}".format(e))
                traceback.print_exc()
    else:
        # Serial path builds one shared deepforest model instead of one per task.
        from deepforest import deepforest
        deepforest_model = deepforest.deepforest()
        deepforest_model.use_release()
        for plot in plot_names:
            try:
                plot_HSI_crops, plot_RGB_crops, plot_labels, plot_sites, plot_elevations, plot_box_index = run(
                    plot=plot,
                    df=df,
                    rgb_pool=rgb_pool,
                    hyperspectral_pool=hyperspectral_pool, 
                    extend_box=extend_box,
                    hyperspectral_savedir=hyperspectral_savedir,
                    saved_model=saved_model,
                    deepforest_model=deepforest_model
                )
            except Exception as e:
                print("Plot failed with {}".format(e))
                traceback.print_exc()
                continue

            #Append to general plot list
            HSI_crops.extend(plot_HSI_crops)
            RGB_crops.extend(plot_RGB_crops)
            labels.extend(plot_labels)
            sites.extend(plot_sites)
            elevations.extend(plot_elevations)
            box_indexes.extend(plot_box_index)

    if shuffle:
        # Shuffle all six parallel lists in lockstep.
        z = list(zip(HSI_crops, RGB_crops, sites, elevations, box_indexes, labels))
        random.shuffle(z)
        HSI_crops, RGB_crops, sites, elevations, box_indexes, labels = zip(*z)

    #If passed a species label dict, reuse it; otherwise create and save a new one
    if species_classes_file is not None:
        species_classdf = pd.read_csv(species_classes_file)
        species_label_dict = species_classdf.set_index("taxonID").label.to_dict()
    else:
        unique_species_labels = np.unique(labels)
        species_label_dict = {}

        for index, label in enumerate(unique_species_labels):
            species_label_dict[label] = index
        pd.DataFrame(species_label_dict.items(), columns=["taxonID","label"]).to_csv("{}/species_class_labels.csv".format(savedir))

    #If passed a site label dict, reuse it; otherwise create and save a new one
    if site_classes_file is not None:
        site_classdf = pd.read_csv(site_classes_file)
        site_label_dict = site_classdf.set_index("siteID").label.to_dict()
    else:
        unique_site_labels = np.unique(sites)
        site_label_dict = {}

        for index, label in enumerate(unique_site_labels):
            site_label_dict[label] = index
        pd.DataFrame(site_label_dict.items(), columns=["siteID","label"]).to_csv("{}/site_class_labels.csv".format(savedir))

    #Convert labels to numeric
    numeric_labels = [species_label_dict[x] for x in labels]
    numeric_sites = [site_label_dict[x] for x in sites]

    #Write tfrecords
    tfrecords = create_records(
        HSI_crops=HSI_crops,
        RGB_crops=RGB_crops,
        labels=numeric_labels, 
        sites=numeric_sites, 
        elevations=elevations,
        box_index=box_indexes, 
        savedir=savedir, 
        RGB_size=RGB_size,
        HSI_size=HSI_size, 
        chunk_size=chunk_size)

    # Bug fix: the original ended with a duplicated, unreachable `return tfrecords`.
    return tfrecords


if __name__ == "__main__":
    #Generate the training data shapefiles
    ROOT = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    lookup_glob = "/orange/ewhite/NeonData/**/CanopyHeightModelGtif/*.tif"

    #Read config from top level dir
    config = parse_yaml("{}/conf/tree_config.yml".format(ROOT))

    #create dask client
    client = start_cluster.start(cpus=config["cpu_workers"], mem_size="10GB")
    #client = None

    #Create train test split
    create_training_shp.train_test_split(
        ROOT,
        lookup_glob,
        n=config["train"]["resampled_per_taxa"],
        client=client,
        regenerate=False)

    #test data
    main(
        field_data=config["evaluation"]["ground_truth_path"],
        RGB_size=config["train"]["RGB"]["crop_size"],
        HSI_size=config["train"]["HSI"]["crop_size"],
Example #6
0
#Generate tfrecords
from DeepTreeAttention.trees import AttentionModel
from DeepTreeAttention.generators import boxes
from DeepTreeAttention.utils.start_cluster import start
from DeepTreeAttention.utils.paths import lookup_and_convert

from distributed import wait
import glob
import os

# Build the model wrapper; its config supplies the sensor pools used below.
att = AttentionModel(config="/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml")

# Start a dask cluster for parallel tfrecord generation.
client = start(cpus=10, mem_size="5GB") 

# Generate training data: one generate() task per weak-label shapefile.
train_tfrecords = []
shapefiles = glob.glob(os.path.join("/orange/idtrees-collab/DeepTreeAttention/WeakLabels/","*.shp"))
for shapefile in shapefiles:
    # Resolve (and convert, if needed) the sensor raster for this shapefile
    # before submitting; the conversion itself runs on the client side.
    sensor_path = lookup_and_convert(shapefile, rgb_pool=att.config["train"]["rgb_sensor_pool"], hyperspectral_pool=att.config["train"]["hyperspectral_sensor_pool"], savedir=att.config["hyperspectral_tif_dir"])
    future = client.submit(att.generate, shapefile=shapefile, sensor_path=sensor_path, chunk_size=10000, train=True)
    train_tfrecords.append(future)
    
wait(train_tfrecords)
# Surface any task exception; results themselves are discarded.
for x in train_tfrecords:
    x.result()
        
Example #7
0
def main(field_data,
         height,
         width,
         rgb_pool=None,
         hyperspectral_pool=None,
         sensor="hyperspectral",
         savedir=".",
         chunk_size=200,
         extend_box=0,
         hyperspectral_savedir=".",
         n_workers=20,
         saved_model=None,
         use_dask=True,
         shuffle=True):
    """Prepare NEON field data into tfrecords
    Args:
        field_data: shp file with location and class of each field collected point
        height: height in meters of the resized training image
        width: width in meters of the resized training image
        rgb_pool: candidate rgb tiles (required when sensor == "rgb")
        hyperspectral_pool: candidate hyperspectral tiles (required when sensor == "hyperspectral")
        sensor: 'rgb' or 'hyperspectral' image crop
        savedir: directory to save completed tfrecords
        chunk_size: number of crops per tfrecord shard
        extend_box: units in meters to add to the edge of a predicted box to give more context
        hyperspectral_savedir: location to save converted .h5 to .tif
        n_workers: number of dask workers
        saved_model: optional path to a saved model passed through to run()
        use_dask: process one plot per dask task instead of serially
        shuffle: shuffle lists before writing
    Returns:
        tfrecords: list of created tfrecords
    """
    #Check that the chosen sensor type has a matching pool of tiles
    #(idiom fix: `X is not None` instead of `not X is None`)
    if sensor == "hyperspectral":
        assert hyperspectral_pool is not None
    if sensor == "rgb":
        assert rgb_pool is not None

    df = gpd.read_file(field_data)
    plot_names = df.plotID.unique()

    labels = []
    crops = []
    box_indexes = []
    if use_dask:
        client = start_cluster.start(cpus=n_workers, mem_size="20GB")
        futures = []
        for plot in plot_names:
            future = client.submit(run,
                                   plot=plot,
                                   df=df,
                                   rgb_pool=rgb_pool,
                                   hyperspectral_pool=hyperspectral_pool,
                                   sensor=sensor,
                                   extend_box=extend_box,
                                   hyperspectral_savedir=hyperspectral_savedir,
                                   saved_model=saved_model)
            futures.append(future)

        wait(futures)
        for x in futures:
            try:
                # Bug fix: removed leftover debug print of plot_box_index[0].
                plot_crops, plot_labels, plot_box_index = x.result()

                #Append to general plot list
                crops.extend(plot_crops)
                labels.extend(plot_labels)
                box_indexes.extend(plot_box_index)
            except Exception as e:
                # Best effort: a failed plot is reported but does not abort the run.
                print("Future failed with {}".format(e))
    else:
        for plot in plot_names:
            plot_crops, plot_labels, plot_box_index = run(
                plot=plot,
                df=df,
                rgb_pool=rgb_pool,
                hyperspectral_pool=hyperspectral_pool,
                sensor=sensor,
                extend_box=extend_box,
                hyperspectral_savedir=hyperspectral_savedir,
                saved_model=saved_model)

            #Append to general plot list
            crops.extend(plot_crops)
            labels.extend(plot_labels)
            box_indexes.extend(plot_box_index)

    if shuffle:
        #Shuffle crops, indexes and labels in lockstep
        z = list(zip(crops, box_indexes, labels))
        random.shuffle(z)
        crops, box_indexes, labels = zip(*z)

    #Convert labels to numeric and persist the mapping for later decoding
    unique_labels = np.unique(labels)
    label_dict = {label: index for index, label in enumerate(unique_labels)}

    numeric_labels = [label_dict[x] for x in labels]
    pd.DataFrame(label_dict.items(),
                 columns=["taxonID", "label"
                          ]).to_csv("{}/class_labels.csv".format(savedir))

    #Write tfrecords
    tfrecords = create_records(crops,
                               numeric_labels,
                               box_indexes,
                               savedir,
                               height,
                               width,
                               chunk_size=chunk_size)

    return tfrecords