コード例 #1
0
def run(plot,
        df,
        savedir,
        raw_box_savedir,
        rgb_pool=None,
        saved_model=None,
        deepforest_model=None):
    """Predict tree boxes for a single plot and write them as shapefiles.

    Wrapper function for dask, see main.py.

    Args:
        plot: plotID value used to select rows of ``df``.
        df: field data with a ``plotID`` column.
        savedir: directory receiving ``<plot>_boxes.shp`` built from the
            first value returned by ``process_plot`` (merged boxes).
        raw_box_savedir: directory receiving ``<plot>_boxes.shp`` built from
            the second value returned by ``process_plot`` (raw boxes).
        rgb_pool: RGB tile paths, passed through to ``process_plot``.
        saved_model: optional path to a saved deepforest model; only used
            when ``deepforest_model`` is None.
        deepforest_model: optional, already-constructed deepforest model.
    """
    from deepforest import deepforest

    # Reuse a supplied model; otherwise prefer the saved model, and fall
    # back to the release weights.
    if deepforest_model is None:
        if saved_model is not None:
            deepforest_model = deepforest.deepforest(saved_model=saved_model)
        else:
            deepforest_model = deepforest.deepforest()
            deepforest_model.use_release()

    selected_rows = df[df.plotID == plot]
    predicted_trees, raw_boxes = process_plot(selected_rows, rgb_pool,
                                              deepforest_model)

    # Both outputs use the same file name; only the target directory
    # differs. Merged boxes are written first, then the raw boxes.
    for frame, target_dir in ((predicted_trees, savedir),
                              (raw_boxes, raw_box_savedir)):
        frame.to_file("{}/{}_boxes.shp".format(target_dir, plot))
コード例 #2
0
def test_multi_train(multi_annotations, tmpdir):
    """Train a two-class model briefly, then check labels survive save/reload.

    Trains from the release weights for one short epoch, checks the model's
    label names are exactly {"Alive", "Dead"}, saves the keras model to
    ``tmpdir`` and reloads it as a new session. Predictions from the reloaded
    session must only use trained label names and must equal the original
    predictions.
    """
    test_model = deepforest.deepforest()
    test_model.use_release()
    test_model.config["epochs"] = 1
    test_model.config["save-snapshot"] = False
    test_model.config["steps"] = 5
    test_model.train(annotations=multi_annotations, input_type="fit_generator")

    boxes = test_model.predict_generator(annotations=multi_annotations)

    # Test labels
    labels = sorted(test_model.labels.values())
    target_labels = sorted(["Dead", "Alive"])

    assert labels == target_labels

    # Test reload models
    saved_path = "{}/prediction_model.h5".format(tmpdir)
    test_model.model.save(saved_path)

    # new object
    new_session = deepforest.deepforest(saved_model=saved_path)
    new_boxes = new_session.predict_generator(annotations=multi_annotations)

    # ``list == ndarray`` is an elementwise comparison and cannot be used as
    # an assert condition, so compare as sets instead. A briefly trained
    # model may predict only one of the classes, so require a subset.
    # NOTE(review): assumes predict_generator names the class column
    # ``label`` (as the ground-truth frames compared against it do).
    assert set(new_boxes["label"].unique()) <= set(labels)
    pd.testing.assert_frame_equal(boxes, new_boxes)
コード例 #3
0
def run(plot,
        df,
        rgb_pool=None,
        hyperspectral_pool=None,
        sensor="hyperspectral",
        extend_box=0,
        hyperspectral_savedir=".",
        saved_model=None):
    """Predict trees for one plot and crop sensor data around each box.

    Wrapper function for dask, see main.py. On any failure the plot id is
    printed and the original exception is re-raised.

    Args:
        plot: plotID value used to select rows of ``df``.
        df: field data with a ``plotID`` column.
        rgb_pool: RGB tile paths, used for prediction and cropping.
        hyperspectral_pool: hyperspectral tile paths, passed to create_crops.
        sensor: sensor name passed to ``create_crops``.
        extend_box: passed to ``create_crops`` as ``expand``.
        hyperspectral_savedir: passed through to ``create_crops``.
        saved_model: optional path to a saved deepforest model; the release
            model is used when None.

    Returns:
        The 3-tuple ``(plot_crops, plot_labels, plot_box_index)`` produced
        by ``create_crops``.
    """
    try:
        from deepforest import deepforest

        if saved_model is not None:
            deepforest_model = deepforest.deepforest(saved_model=saved_model)
        else:
            deepforest_model = deepforest.deepforest()
            deepforest_model.use_release()

        plot_rows = df[df.plotID == plot]
        predicted_trees = process_plot(plot_rows, rgb_pool, deepforest_model)
        crops, crop_labels, box_index = create_crops(
            predicted_trees,
            hyperspectral_pool=hyperspectral_pool,
            rgb_pool=rgb_pool,
            sensor=sensor,
            expand=extend_box,
            hyperspectral_savedir=hyperspectral_savedir)
    except BaseException:
        # Same as a bare ``except:`` - report which plot failed, then let the
        # original exception propagate.
        print("Plot {} failed".format(plot))
        raise

    return crops, crop_labels, box_index
コード例 #4
0
def load_model(weights=None):
    """Return a deepforest model.

    Args:
        weights: optional path to a weights file. When falsy (None or an
            empty string) the release model is loaded instead.

    Returns:
        A ``deepforest.deepforest`` instance.
    """
    if not weights:
        release = deepforest.deepforest()
        release.use_release()
        return release
    return deepforest.deepforest(weights=weights)
コード例 #5
0
def load_model():
    """Load the locally trained model, falling back to the release model.

    Tries ``deepforest-model.h5`` (relative to the current working
    directory) first. If constructing the model from it raises, the release
    model is loaded instead.

    Returns:
        A ``deepforest.deepforest`` instance.
    """
    try:
        model = deepforest.deepforest(saved_model="deepforest-model.h5")
        print("Trained Model loaded")
    except Exception as e:
        # Best-effort fallback, but don't swallow KeyboardInterrupt /
        # SystemExit, and say why the trained model was skipped.
        print("Could not load trained model ({}), falling back".format(e))
        model = deepforest.deepforest()
        model.use_release()
        print("Prebuilt Model loaded")
    return model
コード例 #6
0
def run(plot,
        df,
        rgb_pool=None,
        hyperspectral_pool=None,
        extend_HSI_box=0,
        extend_RGB_box=0,
        hyperspectral_savedir=".",
        saved_model=None,
        deepforest_model=None):
    """Predict trees for one plot and crop matching HSI and RGB data.

    Wrapper function for dask, see main.py. The predicted boxes are also
    written to ``<ROOT>/data/interim/<plot>_boxes.shp`` for inspection.

    Args:
        plot: plotID value used to select rows of ``df``.
        df: field data with a ``plotID`` column.
        rgb_pool: RGB tile paths.
        hyperspectral_pool: hyperspectral tile paths.
        extend_HSI_box: ``expand`` value for the hyperspectral crops.
        extend_RGB_box: ``expand`` value for the RGB crops.
        hyperspectral_savedir: passed through to ``create_crops``.
        saved_model: optional saved deepforest model path; only used when
            ``deepforest_model`` is None.
        deepforest_model: optional, already-constructed deepforest model.

    Returns:
        ``(plot_HSI_crops, plot_rgb_crops, plot_labels, plot_domains,
        plot_sites, plot_heights, plot_elevations, plot_box_index)``.
    """
    from deepforest import deepforest

    if deepforest_model is None:
        if saved_model is not None:
            deepforest_model = deepforest.deepforest(saved_model=saved_model)
        else:
            deepforest_model = deepforest.deepforest()
            deepforest_model.use_release()

    plot_rows = df[df.plotID == plot]
    predicted_trees = process_plot(plot_rows, rgb_pool, deepforest_model,
                                   debug=True)

    # Keep the merged boxes on disk as an interim artifact to inspect.
    interim_dir = os.path.abspath(ROOT)
    interim_file = "{}/data/interim/{}_boxes.shp".format(interim_dir, plot)
    predicted_trees.to_file(interim_file)

    # Arguments shared by both crop passes.
    shared = dict(hyperspectral_pool=hyperspectral_pool,
                  rgb_pool=rgb_pool,
                  hyperspectral_savedir=hyperspectral_savedir)

    # Hyperspectral pass keeps all metadata.
    (hsi_crops, crop_labels, domains, sites, heights, elevations,
     box_index) = create_crops(predicted_trees,
                               sensor="hyperspectral",
                               expand=extend_HSI_box,
                               **shared)

    # RGB pass: only crops and labels are kept.
    rgb_crops, rgb_labels, _, _, _, _, _ = create_crops(predicted_trees,
                                                        sensor="rgb",
                                                        expand=extend_RGB_box,
                                                        **shared)

    # Both sensors must describe the same set of boxes.
    assert len(rgb_crops) == len(hsi_crops)
    assert crop_labels == rgb_labels

    return (hsi_crops, rgb_crops, crop_labels, domains, sites, heights,
            elevations, box_index)
コード例 #7
0
def test_random_transform(annotations):
    """Setting config["random_transform"] adds --random-transform to the args."""
    model = deepforest.deepforest()
    model.config["random_transform"] = True
    classes_path = utilities.create_classes(annotations)
    args = utilities.format_args(annotations, classes_path, model.config)
    assert "--random-transform" in args
コード例 #8
0
def from_retinanet(annotations_csv):
    """Use the keras retinanet source to create detections.

    Loads the release deepforest model, builds a keras-retinanet
    ``CSVGenerator`` over ``annotations_csv`` (in file order, unshuffled),
    and runs retinanet's own ``_get_detections`` / ``_get_annotations``.

    Args:
        annotations_csv: path to the annotations csv used for both the
            classes file and the generator.

    Returns:
        Tuple ``(all_detections, all_annotations, validation_generator)``.
    """
    # The docstring must be the first statement; it previously followed the
    # model setup and was therefore a no-op string, not documentation.
    model = deepforest.deepforest()
    model.use_release()

    ### Keras retinanet
    # Format args for CSV generator
    classes_file = utilities.create_classes(annotations_csv)
    arg_list = utilities.format_args(annotations_csv, classes_file,
                                     model.config)
    args = parse_args(arg_list)

    # create generator; keep file order so detections line up with rows
    validation_generator = csv_generator.CSVGenerator(
        args.annotations,
        args.classes,
        image_min_side=args.image_min_side,
        image_max_side=args.image_max_side,
        config=args.config,
        shuffle_groups=False,
    )

    all_detections = _get_detections(validation_generator,
                                     model.prediction_model,
                                     score_threshold=args.score_threshold,
                                     max_detections=100)
    all_annotations = _get_annotations(validation_generator)

    return all_detections, all_annotations, validation_generator
コード例 #9
0
def test_empty_plot():
    """Exercise the fallback path when no predicted box overlaps the plot.

    Predictions are translated far away so the spatial join is presumably
    empty; boxes are then created from the field data and one box per
    ``individual`` is chosen.
    """
    # DeepForest prediction
    model = deepforest.deepforest()
    model.use_release()
    plot_data = gpd.read_file(data_path)
    rgb_sensor_path = prepare_field_data.find_sensor_path(
        bounds=plot_data.total_bounds, lookup_pool=rgb_pool)
    boxes = prepare_field_data.predict_trees(deepforest_model=model,
                                             rgb_path=rgb_sensor_path,
                                             bounds=plot_data.total_bounds)

    # Fake offset: shift every predicted geometry by a large scalar.
    boxes["geometry"] = boxes["geometry"].translate(100000)

    # Merge results with field data; with nothing left, build a box around
    # each field point instead.
    merged_boxes = gpd.sjoin(boxes, plot_data)
    if merged_boxes.empty:
        merged_boxes = prepare_field_data.create_boxes(plot_data)

    # Keep a single (center-most) box per individual.
    chosen_boxes = [
        prepare_field_data.choose_box(group, plot_data)
        for _, group in merged_boxes.groupby("individual")
    ]

    merged_boxes = gpd.GeoDataFrame(pd.concat(chosen_boxes),
                                    crs=merged_boxes.crs)
    merged_boxes = merged_boxes.drop(columns=["xmin", "xmax", "ymin", "ymax"])
コード例 #10
0
def test_process_plot():
    """process_plot returns at least as many boxes as field-data rows."""
    field_points = gpd.read_file(data_path)

    model = deepforest.deepforest()
    model.use_release()

    merged_boxes = prepare_field_data.process_plot(plot_data=field_points,
                                                   rgb_pool=rgb_pool,
                                                   deepforest_model=model)
    assert merged_boxes.shape[0] >= field_points.shape[0]
コード例 #11
0
def test_train(annotations):
    """Train for one single-step epoch and return the trained model."""
    model = deepforest.deepforest()
    overrides = (("epochs", 1), ("save-snapshot", False), ("steps", 1))
    for key, value in overrides:
        model.config[key] = value
    model.train(annotations=annotations, input_type="fit_generator")

    return model
コード例 #12
0
def run(plot,
        df,
        rgb_pool=None,
        hyperspectral_pool=None,
        extend_box=0,
        hyperspectral_savedir=".",
        saved_model=None,
        deepforest_model=None):
    """Predict trees for one plot and crop matching HSI and RGB data.

    Wrapper function for dask, see main.py. Any exception is printed with
    the plot id and then re-raised.

    Args:
        plot: plotID value used to select rows of ``df``.
        df: field data with a ``plotID`` column.
        rgb_pool: RGB tile paths.
        hyperspectral_pool: hyperspectral tile paths.
        extend_box: ``expand`` value used for both crop passes.
        hyperspectral_savedir: passed through to ``create_crops``.
        saved_model: optional saved deepforest model path; only used when
            ``deepforest_model`` is None.
        deepforest_model: optional, already-constructed deepforest model.

    Returns:
        ``(plot_HSI_crops, plot_rgb_crops, plot_labels, plot_sites,
        plot_elevations, plot_box_index)``.
    """
    try:
        from deepforest import deepforest

        if deepforest_model is None:
            if saved_model is not None:
                deepforest_model = deepforest.deepforest(
                    saved_model=saved_model)
            else:
                deepforest_model = deepforest.deepforest()
                deepforest_model.use_release()

        plot_rows = df[df.plotID == plot]
        predicted_trees = process_plot(plot_rows, rgb_pool, deepforest_model)

        # Arguments shared by both crop passes.
        crop_kwargs = dict(hyperspectral_pool=hyperspectral_pool,
                           rgb_pool=rgb_pool,
                           expand=extend_box,
                           hyperspectral_savedir=hyperspectral_savedir)

        # Hyperspectral pass keeps the metadata.
        hsi_crops, crop_labels, sites, elevations, box_index = create_crops(
            predicted_trees, sensor="hyperspectral", **crop_kwargs)

        # RGB pass: only crops and labels are kept.
        rgb_crops, rgb_labels, _, _, _ = create_crops(predicted_trees,
                                                      sensor="rgb",
                                                      **crop_kwargs)

        # Both sensors must describe the same set of boxes.
        assert len(rgb_crops) == len(hsi_crops)
        assert crop_labels == rgb_labels
    except Exception as e:
        print("Plot {} failed {}".format(plot, e))
        raise

    return hsi_crops, rgb_crops, crop_labels, sites, elevations, box_index
コード例 #13
0
def test_predict_image(download_release):
    """Predicting with the universal July-30 weights yields a 4-column array."""
    model = deepforest.deepforest(
        weights="tests/data/universal_model_july30.h5")
    assert isinstance(model.model, keras.models.Model)

    predictions = model.predict_image(image_path="tests/data/OSBS_029.tif",
                                      show=False,
                                      return_plot=False)

    # Expect a numpy array with 4 columns.
    assert isinstance(predictions, np.ndarray)
    assert predictions.shape[1] == 4
コード例 #14
0
def release_model(download_release):
    """Build a release-weights deepforest model, sanity-check it, return it."""
    model = deepforest.deepforest()
    model.use_release()

    # use_release() should record a release tag and load a keras model.
    assert isinstance(model.__release_version__, str)
    assert isinstance(model.model, keras.models.Model)

    return model
コード例 #15
0
def test_use_release(download_release):
    """use_release() records a release tag, loads a keras model and sets weights.

    Also checks that ``config["weights"]`` mirrors ``model.weights`` and is
    not the placeholder string "None".
    """
    test_model = deepforest.deepforest()
    test_model.use_release()

    # Check for release tag
    assert isinstance(test_model.__release_version__, str)

    # Assert is model instance
    assert isinstance(test_model.model, keras.models.Model)

    assert test_model.config["weights"] == test_model.weights
    # Compare by value: ``is not "None"`` tested object identity against a
    # string literal (SyntaxWarning since Python 3.8) and is unreliable.
    assert test_model.config["weights"] != "None"
コード例 #16
0
def from_source(annotations_csv):
    """Use the code in this repo to get detections.

    Args:
        annotations_csv: headerless csv with image_path, xmin, ymin, xmax,
            ymax, label columns.

    Returns:
        ``(predicted_boxes, true_boxes)``: the release model's
        ``predict_generator`` output, and the ground truth read from
        ``annotations_csv`` with an added ``plot_name`` column (the image
        path without its extension).
    """
    model = deepforest.deepforest()
    model.use_release()
    predicted_boxes = model.predict_generator(annotations_csv)

    column_names = ["image_path", "xmin", "ymin", "xmax", "ymax", "label"]
    true_boxes = pd.read_csv(annotations_csv, names=column_names)

    def strip_extension(path):
        return os.path.splitext(path)[0]

    true_boxes["plot_name"] = true_boxes.image_path.apply(strip_extension)

    return predicted_boxes, true_boxes
コード例 #17
0
def main(field_data,
         rgb_dir,
         savedir,
         raw_box_savedir,
         saved_model=None,
         client=None,
         shuffle=True):
    """Predict tree bounding boxes for every plot in the NEON field data.

    Args:
        field_data: shp file with location and class of each field collected point
        rgb_dir: glob pattern (expanded recursively) matching the RGB tiles
        savedir: directory to save predicted bounding boxes
        raw_box_savedir: directory to save all bounding boxes in the image
        saved_model: optional path to a saved deepforest model; the release
            model is used when None
        client: dask client object to use; plots run serially when None
        shuffle: currently unused; kept for interface compatibility
    Returns:
        None: .shp bounding boxes are written to savedir and raw_box_savedir
    """
    df = gpd.read_file(field_data)
    plot_names = df.plotID.unique()

    rgb_pool = glob.glob(rgb_dir, recursive=True)

    if client is not None:
        futures = []
        for plot in plot_names:
            future = client.submit(run,
                                   plot=plot,
                                   df=df,
                                   rgb_pool=rgb_pool,
                                   saved_model=saved_model,
                                   savedir=savedir,
                                   raw_box_savedir=raw_box_savedir)
            futures.append(future)

        wait(futures)

        # Report failed tasks instead of silently dropping them, matching the
        # serial branch below. Futures were submitted in plot_names order.
        for plot, future in zip(plot_names, futures):
            if future.status == "error":
                print("Plot {} failed with {}".format(plot, future.exception()))

    else:
        from deepforest import deepforest

        # Build the model once and reuse it for every plot. run() only loads
        # saved_model when no deepforest_model is passed, so honor it here.
        if saved_model is None:
            deepforest_model = deepforest.deepforest()
            deepforest_model.use_release()
        else:
            deepforest_model = deepforest.deepforest(saved_model=saved_model)

        for plot in plot_names:
            try:
                # savedir is a required argument of run(); omitting it made
                # every plot fail with a TypeError.
                run(plot=plot,
                    df=df,
                    savedir=savedir,
                    raw_box_savedir=raw_box_savedir,
                    rgb_pool=rgb_pool,
                    saved_model=saved_model,
                    deepforest_model=deepforest_model)
            except Exception as e:
                print("Plot failed with {}".format(e))
                traceback.print_exc()
コード例 #18
0
def test_multi_train(multi_annotations):
    """A briefly trained two-class model exposes exactly {"Alive", "Dead"}."""
    model = deepforest.deepforest()
    model.config["epochs"] = 1
    model.config["save-snapshot"] = False
    model.config["steps"] = 1
    model.train(annotations=multi_annotations, input_type="fit_generator")

    # Order-insensitive comparison of the learned label names.
    expected = sorted(["Dead", "Alive"])
    learned = sorted(model.labels.values())

    assert learned == expected
コード例 #19
0
def test_reload_weights(release_model):
    """Weights saved from the release model reload and can predict."""
    weights_path = "tests/output/example_saved_weights.h5"
    release_model.model.save_weights(weights_path)

    reloaded = deepforest.deepforest(weights=weights_path)
    assert reloaded.prediction_model

    # Predict test image and return boxes
    predictions = reloaded.predict_image(image_path=get_data("OSBS_029.tif"),
                                         show=False,
                                         return_plot=False,
                                         score_threshold=0.1)

    # Six columns: xmin, ymin, xmax, ymax, score, label
    assert predictions.shape[1] == 6
コード例 #20
0
def test_predict_image(download_release):
    """NEON-weight predictions have six columns and honor score_threshold."""
    model = deepforest.deepforest(weights=get_data("NEON.h5"))
    assert isinstance(model.model, keras.models.Model)

    predictions = model.predict_image(image_path=get_data("OSBS_029.tif"),
                                      show=False,
                                      return_plot=False,
                                      score_threshold=0.1)

    # Six columns (xmin, ymin, xmax, ymax, score, label). The ``.score``
    # attribute access below implies a DataFrame-like result, not a bare
    # numpy array.
    assert predictions.shape[1] == 6

    assert predictions.score.min() > 0.1
コード例 #21
0
def submission_no_chm(tiles_to_predict):
    """Run the release deepforest model on every tile and stack the results.

    Args:
        tiles_to_predict: iterable of image paths.

    Returns:
        pandas.DataFrame of all predictions, with a ``plot_name`` column set
        to each tile's file name without directory or extension. Tiles whose
        prediction raises are skipped after printing the error; if no tile
        succeeds an empty DataFrame is returned (``pd.concat`` would raise
        ``ValueError`` on an empty list).
    """
    model = deepforest.deepforest()
    model.use_release()

    results = []
    for path in tiles_to_predict:
        try:
            result = model.predict_image(path, return_plot=False)
            result["plot_name"] = os.path.splitext(os.path.basename(path))[0]
        except Exception as e:
            # Best-effort: skip this tile, but say which one failed.
            print("Prediction failed for {}: {}".format(path, e))
            continue
        results.append(result)

    if not results:
        return pd.DataFrame()

    # Stack per-tile predictions into one frame
    return pd.concat(results)
コード例 #22
0
def test_mAP_deepforest(annotations_csv):
    """Retinanet's evaluate_generator mAP equals this repo's calculate_mAP."""
    model = deepforest.deepforest()
    model.use_release()

    # Reference: original retinanet implementation
    reference_mAP = model.evaluate_generator(annotations=annotations_csv)

    # Candidate: this repo's implementation on the same annotations
    predicted = model.predict_generator(annotations_csv)
    truth = pd.read_csv(
        annotations_csv,
        names=["image_path", "xmin", "ymin", "xmax", "ymax", "label"])
    truth["plot_name"] = truth.image_path.apply(
        lambda image_path: os.path.splitext(image_path)[0])
    repo_mAP = average_precision.calculate_mAP(truth,
                                               predicted,
                                               iou_threshold=0.5)

    assert reference_mAP == repo_mAP
コード例 #23
0
def _load_or_create_label_dict(classes_file, values, key_column, savedir,
                               filename):
    """Return a {name: integer label} dict, read from file or built from values.

    When ``classes_file`` is given it is read as a csv with ``key_column``
    and ``label`` columns. Otherwise the sorted unique ``values`` are
    numbered from 0 and the mapping is written to ``<savedir>/<filename>``.
    """
    if classes_file is not None:
        classdf = pd.read_csv(classes_file)
        return classdf.set_index(key_column).label.to_dict()

    # np.unique returns the values sorted, so numbering is deterministic.
    label_dict = {
        label: index for index, label in enumerate(np.unique(values))
    }
    pd.DataFrame(label_dict.items(), columns=[key_column, "label"]).to_csv(
        "{}/{}".format(savedir, filename))
    return label_dict


def main(
    field_data,
    RGB_size,
    HSI_size,
    rgb_dir,
    hyperspectral_dir,
    savedir=".",
    chunk_size=400,
    extend_box=0,
    hyperspectral_savedir=".",
    n_workers=20,
    saved_model=None,
    use_dask=False,
    shuffle=True,
    species_classes_file=None,
    site_classes_file=None):
    """Prepare NEON field data into tfrecords

    Args:
        field_data: shp file with location and class of each field collected point
        RGB_size: RGB crop size, passed to create_records
        HSI_size: hyperspectral crop size, passed to create_records
        rgb_dir: glob pattern (expanded recursively) matching RGB tiles
        hyperspectral_dir: glob pattern (expanded recursively) matching hyperspectral tiles
        savedir: directory to save completed tfrecords and generated label csv files
        chunk_size: passed to create_records
        extend_box: units in meters to add to the edge of a predicted box to give more context
        hyperspectral_savedir: location to save converted .h5 to .tif
        n_workers: number of dask workers
        saved_model: optional path to a saved deepforest model; the release model is used when None
        use_dask: process plots on a dask cluster instead of serially
        shuffle: shuffle lists before writing
        species_classes_file: optional path to a two column csv file with index and species labels
        site_classes_file: optional path to a two column csv file with index and site labels
    Returns:
        tfrecords: list of created tfrecords
    """
    df = gpd.read_file(field_data)
    plot_names = df.plotID.unique()

    hyperspectral_pool = glob.glob(hyperspectral_dir, recursive=True)
    rgb_pool = glob.glob(rgb_dir, recursive=True)

    labels = []
    HSI_crops = []
    RGB_crops = []
    sites = []
    box_indexes = []
    elevations = []

    def _append_plot(result):
        # Unpack one run() result and append it to the accumulated lists.
        (plot_HSI_crops, plot_RGB_crops, plot_labels, plot_sites,
         plot_elevations, plot_box_index) = result
        HSI_crops.extend(plot_HSI_crops)
        RGB_crops.extend(plot_RGB_crops)
        labels.extend(plot_labels)
        sites.extend(plot_sites)
        elevations.extend(plot_elevations)
        box_indexes.extend(plot_box_index)

    if use_dask:
        client = start_cluster.start(cpus=n_workers, mem_size="10GB")
        futures = []
        for plot in plot_names:
            future = client.submit(
                run,
                plot=plot,
                df=df,
                rgb_pool=rgb_pool,
                hyperspectral_pool=hyperspectral_pool,
                extend_box=extend_box,
                hyperspectral_savedir=hyperspectral_savedir,
                saved_model=saved_model
            )
            futures.append(future)

        wait(futures)
        for x in futures:
            try:
                _append_plot(x.result())
            except Exception as e:
                print("Future failed with {}".format(e))
                traceback.print_exc()
    else:
        from deepforest import deepforest

        # Build one model for all plots. run() only loads saved_model when
        # no deepforest_model is passed, so honor saved_model here.
        if saved_model is None:
            deepforest_model = deepforest.deepforest()
            deepforest_model.use_release()
        else:
            deepforest_model = deepforest.deepforest(saved_model=saved_model)

        for plot in plot_names:
            try:
                _append_plot(run(
                    plot=plot,
                    df=df,
                    rgb_pool=rgb_pool,
                    hyperspectral_pool=hyperspectral_pool,
                    extend_box=extend_box,
                    hyperspectral_savedir=hyperspectral_savedir,
                    saved_model=saved_model,
                    deepforest_model=deepforest_model
                ))
            except Exception as e:
                print("Plot failed with {}".format(e))
                traceback.print_exc()

    # Shuffle all lists with the same permutation. Skip when nothing was
    # collected: unpacking zip(*[]) into six names would raise ValueError.
    if shuffle and HSI_crops:
        z = list(zip(HSI_crops, RGB_crops, sites, elevations, box_indexes, labels))
        random.shuffle(z)
        HSI_crops, RGB_crops, sites, elevations, box_indexes, labels = zip(*z)

    species_label_dict = _load_or_create_label_dict(
        species_classes_file, labels, "taxonID", savedir,
        "species_class_labels.csv")
    site_label_dict = _load_or_create_label_dict(
        site_classes_file, sites, "siteID", savedir,
        "site_class_labels.csv")

    #Convert labels to numeric
    numeric_labels = [species_label_dict[x] for x in labels]
    numeric_sites = [site_label_dict[x] for x in sites]

    #Write tfrecords
    tfrecords = create_records(
        HSI_crops=HSI_crops,
        RGB_crops=RGB_crops,
        labels=numeric_labels,
        sites=numeric_sites,
        elevations=elevations,
        box_index=box_indexes,
        savedir=savedir,
        RGB_size=RGB_size,
        HSI_size=HSI_size,
        chunk_size=chunk_size)

    return tfrecords
コード例 #24
0
def test_use_release(download_release):
    """use_release() leaves a keras Model on the instance."""
    released = deepforest.deepforest()
    released.use_release()
    assert isinstance(released.model, keras.models.Model)
コード例 #25
0
def test_train(annotations):
    """Train for a single epoch (no snapshots) on the annotations fixture."""
    test_model = deepforest.deepforest()
    test_model.config["epochs"] = 1
    test_model.config["save-snapshot"] = False
    # Use the requested fixture instead of a hard-coded path (the parameter
    # was previously unused). NOTE(review): assumes the fixture yields the
    # annotations csv path, as other tests pass it straight to train().
    test_model.train(annotations=annotations)
コード例 #26
0
def test_deepforest():
    """Constructing with weights=None leaves the weights attribute None."""
    untrained = deepforest.deepforest(weights=None)
    assert untrained.weights is None
コード例 #27
0
import os

tfrecords_path = tfrecords.create_tfrecords(annotations_file,
                                            class_file,
                                            size=1)
print("Created {} tfrecords: {}".format(len(tfrecords_path), tfrecords_path))
inputs, targets = tfrecords.create_tensors(tfrecords_path)

#### Fit generator ##
# Read the Comet API key from the environment instead of committing it to
# source; with api_key=None comet_ml falls back to COMET_API_KEY / its
# config file. NOTE(review): the key previously hard-coded here is public
# and should be revoked.
comet_experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                              project_name="deepforest",
                              workspace="bw4sz")

comet_experiment.log_parameter("Type", "testing")
comet_experiment.log_parameter("input_type", "fit_generator")

#Create model
fitgen_model = deepforest.deepforest()
fitgen_model.config["epochs"] = 1
comet_experiment.log_parameters(fitgen_model.config)

#Train model
fitgen_model.train(annotations_file,
                   input_type="fit_generator",
                   comet_experiment=comet_experiment,
                   images_per_epoch=1000)

#Evaluate on original annotations
mAP = fitgen_model.evaluate_generator(annotations_file, comet_experiment)
boxes = fitgen_model.prediction_model.predict(inputs, steps=1)

comet_experiment.log_metric("mAP", mAP)
コード例 #28
0
import os
from typing import Any, Dict, List, Optional

import pandas as pd
from deepforest import deepforest

from _create_xml import create_label_xml

# NOTE: despite the name, this is the absolute current working directory,
# not the directory containing this script.
SCRIPT_DIR = os.path.abspath(os.getcwd())

# Data locations, relative to the current working directory.
IMAGE_DIR_PATH = "../data/image_data"
LABEL_DIR_PATH = "../data/predictions"
MODEL_DIR_PATH = "../data/model_data"

IMAGE_WIDTH, IMAGE_HEIGHT = 400, 400

# Loaded once at import time and shared by every prediction.
MODEL = deepforest.deepforest(saved_model=f"{MODEL_DIR_PATH}/model.h5")


def _predict_image(image_path: str) -> Optional[pd.DataFrame]:
    """Predict boxes for one image with the module-level MODEL.

    Args:
        image_path: path of the image to predict.

    Returns:
        Whatever ``MODEL.predict_image`` returns, or None if prediction
        raised an exception.
    """
    try:
        return MODEL.predict_image(str(image_path),
                                   show=False,
                                   return_plot=False)
    except Exception as e:
        # Best-effort per image, but let KeyboardInterrupt/SystemExit through
        # (a bare ``except:`` swallowed them) and say what went wrong.
        print(f"Prediction failed for {image_path}: {e}")
        return None


def _create_label_dir(dir_path: str) -> None:
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    for sub_dir in ["csv", "xml"]:
コード例 #29
0
import glob
import os

# comet_ml is imported before deepforest/keras so it can hook into them.
from comet_ml import Experiment
import pandas as pd  # used below to read the pretraining csv
from deepforest import deepforest

#Local debug. If False, paths on UF hypergator supercomputing cluster
DEBUG = False
INPUT_TYPE = "tfrecord"

if DEBUG:
    BASE_PATH = "/Users/ben/Documents/NeonTreeEvaluation_analysis/Weinstein_unpublished/"
else:
    BASE_PATH = "/orange/ewhite/b.weinstein/NeonTreeEvaluation/"

##Pretrain deepforest on Silva annotations
deepforest_model = deepforest.deepforest()

# import comet_ml logger. The API key comes from the environment instead of
# being committed; with api_key=None comet_ml falls back to COMET_API_KEY /
# its config file. NOTE(review): the previously hard-coded key is public
# and should be revoked.
comet_experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                              project_name="deepforest",
                              workspace="bw4sz")

comet_experiment.log_parameters(deepforest_model.config)
comet_experiment.log_parameter("Type", "Pretraining")

if INPUT_TYPE == "fit_generator":
    comet_experiment.log_parameter("Profiler", "fit_generator")
    training_file = pd.read_csv(
        BASE_PATH + "pretraining/crops/pretraining.csv",
        names=["image_path", "xmin", "ymin", "xmax", "ymax", "label"])
    unique_images = training_file.image_path.unique()
コード例 #30
0
def test_predict_image_raise_error(download_release):
    """predict_image() called with no image input raises ValueError."""
    model = deepforest.deepforest(weights=get_data("NEON.h5"))
    with pytest.raises(ValueError):
        model.predict_image()