예제 #1
0
def test_set_yolo_parameters():
    """Exercise argument validation of ObjectDetectionModel.set_yolo_parameters.

    Expected behaviour pinned by this test:
    * RuntimeError while image dimensions are still unset;
    * the default-argument call succeeds once dimensions are set;
    * TypeError for every malformed argument set listed below;
    * a well-formed (list, list of str, list of 2-tuples) call succeeds.
    """
    model = dpp.ObjectDetectionModel()

    # No image dimensions configured yet, so even the default call must fail.
    with pytest.raises(RuntimeError):
        model.set_yolo_parameters()

    model.set_image_dimensions(448, 448, 3)
    model.set_yolo_parameters()

    # Each row differs from the accepted call at the bottom in exactly one
    # argument. The roles (presumably grid size, class names, anchor pairs)
    # are inferred from the values -- confirm against dpp.
    malformed_calls = [
        (True, ['plant', 'knat'], [(100, 30), (200, 10), (50, 145)]),
        (13, ['plant', 'knat'], [(100, 30), (200, 10), (50, 145)]),
        ([13], ['plant', 'knat'], [(100, 30), (200, 10), (50, 145)]),
        ([13, 13], 'plant', [(100, 30), (200, 10), (50, 145)]),
        ([13, 13], ['plant', 2], [(100, 30), (200, 10), (50, 145)]),
        ([13, 13], ['plant', 'knat'], 100),
        ([13, 13], ['plant', 'knat'], [(100, 30), (200, 10), 50]),
        ([13, 13], ['plant', 'knat'], [(100, 30), (200, 10), (145, )]),
        ([13, 13], ['plant', 'knat'], [(100, 30), (200, 10), (145, 'a')]),
    ]
    for first_arg, second_arg, third_arg in malformed_calls:
        with pytest.raises(TypeError):
            model.set_yolo_parameters(first_arg, second_arg, third_arg)

    # The well-formed combination is accepted.
    model.set_yolo_parameters([13, 13], ['plant', 'knat'],
                              [(100, 30), (200, 10), (50, 145)])
예제 #2
0
    with pytest.raises(TypeError):
        model.set_patch_size(1.0, 1)
    with pytest.raises(ValueError):
        model.set_patch_size(-1, 1)
    with pytest.raises(TypeError):
        model.set_patch_size(1, 1.0)
    with pytest.raises(ValueError):
        model.set_patch_size(1, -1)


@pytest.mark.parametrize(
    ("model", "bad_loss", "good_loss"),
    [
        # Each case: a model instance, a loss name it must reject, and a loss
        # name it must accept. The instances are built once, when this
        # decorator is evaluated at import time.
        pytest.param(dpp.ClassificationModel(), 'l2', 'softmax cross entropy'),
        pytest.param(dpp.RegressionModel(), 'softmax cross entropy', 'l2'),
        pytest.param(dpp.SemanticSegmentationModel(), 'l2', 'sigmoid cross entropy'),
        pytest.param(dpp.ObjectDetectionModel(), 'l2', 'yolo'),
        pytest.param(dpp.CountCeptionModel(), 'l2', 'l1'),
        pytest.param(dpp.HeatmapObjectCountingModel(), 'l1', 'sigmoid cross entropy'),
    ],
)
def test_set_loss_function(model, bad_loss, good_loss):
    """set_loss_function rejects an int with TypeError, rejects the
    model-specific ``bad_loss`` with ValueError and accepts ``good_loss``."""
    # An integer argument must be refused outright.
    with pytest.raises(TypeError):
        model.set_loss_function(0)

    # A loss name this model type does not support must be refused.
    with pytest.raises(ValueError):
        model.set_loss_function(bad_loss)

    # A supported loss name must be accepted without error.
    model.set_loss_function(good_loss)


def test_set_yolo_parameters():
    """set_yolo_parameters must raise RuntimeError before image dimensions exist.

    NOTE(review): this body stops right after set_image_dimensions() and looks
    truncated. If another ``test_set_yolo_parameters`` is defined in the same
    module, whichever definition comes last silently replaces the other.
    Confirm against the upstream test file.
    """
    detector = dpp.ObjectDetectionModel()

    # Dimensions are not configured yet, so the call must be refused.
    with pytest.raises(RuntimeError):
        detector.set_yolo_parameters()

    detector.set_image_dimensions(448, 448, 3)
#
# Demonstrates the process of training a YOLO-based object detector in DPP.
#

import deepplantphenomics as dpp

model = dpp.ObjectDetectionModel(
    debug=True,
    save_checkpoints=False,
    tensorboard_dir='tensor_logs',
    report_rate=20,
)

# Image channel count: 3 for colour input, 1 for greyscale.
channels = 3
# Side length used for both the image dimensions and the patch size below.
image_side = 448

# Runtime and input configuration.
model.set_batch_size(1)
model.set_number_of_threads(4)
model.set_image_dimensions(image_side, image_side, channels)
model.set_resize_images(False)
model.set_patch_size(image_side, image_side)

# set_yolo_parameters() is deliberately not called: this example relies on
# its default values.

# Data split and training hyper-parameters.
model.set_test_split(0.1)
model.set_validation_split(0)
model.set_learning_rate(1e-6)
model.set_weight_initializer('xavier')
model.set_maximum_training_epochs(100)

# Load the dataset from ./yolo_data, with labels read from 'labels.json' and
# images from 'images' (presumably both relative to ./yolo_data -- confirm).
model.load_yolo_dataset_from_directory('./yolo_data',
                                       label_file='labels.json',
                                       image_dir='images')