Ejemplo n.º 1
0
def test_reload_model(test_dataset, test_anchors, test_masks, test_classes):
    """Train a MobileNetV2 YoloV3 briefly, save it, then reload its weights.

    Steps:
    1. Build a training-mode model (``training=True``), compile it with Adam
       (learning rate 1e-4) and the loss from ``losses.make_loss``, and fit it.
    2. Save it to ``model.h5``.
    3. Build a fresh inference-mode model (``training=False``) with the same
       configuration and load the saved weights into it.
    4. Run the reloaded model on an all-zeros image.

    The checkpoint goes into a per-test temporary directory that is deleted
    on exit, rather than a fixed location inside the project tree. The test
    therefore leaves no artifacts behind, and concurrent runs cannot overwrite
    each other's file.
    """
    import tempfile
    from pathlib import Path

    img_shape = test_dataset.target_shape
    model = YoloV3(img_shape,
                   test_dataset.max_objects,
                   backbone='MobileNetV2',
                   anchors=test_anchors,
                   num_classes=len(test_classes),
                   training=True)

    loss_fn = losses.make_loss(model.num_classes, test_anchors, test_masks,
                               img_shape[0], len(test_dataset))
    optimizer = model.get_optimizer('adam', 1e-4)
    model.compile(optimizer, loss_fn, run_eagerly=True)
    # NOTE(review): presumably YoloV3.fit takes (train_data, val_data, epochs)
    # and the same dataset is reused for both; confirm against YoloV3.fit.
    model.fit(test_dataset, test_dataset, 1)

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = Path(tmp_dir) / 'model.h5'
        model.save(save_path)

        # Drop the training model before building the inference-mode one.
        del model

        model = YoloV3(img_shape,
                       test_dataset.max_objects,
                       backbone='MobileNetV2',
                       anchors=test_anchors,
                       num_classes=len(test_classes),
                       training=False)
        model.load_weights(save_path)

        _, _, _, valid_detections = model.predict(tf.zeros((1, *img_shape)))

    # NOTE(review): for a single-image batch this length check may only
    # verify the batch dimension, not that any detection was produced.
    # Confirm the intended assertion.
    assert len(valid_detections) > 0
Ejemplo n.º 2
0
def test_model_resnet50(test_dataset, test_anchors, test_masks, test_classes):
    """Smoke-test a YoloV3 model built on the ResNet50V2 backbone.

    The model is built in training mode and compiled with an SGD optimizer
    (learning rate 1e-4) and the loss from ``losses.make_loss``. It is then
    called directly on ``test_dataset[0][0]``. The only check is that the
    call returns a non-None result.
    """
    input_shape = test_dataset.target_shape
    class_count = len(test_classes)

    yolo = YoloV3(input_shape,
                  test_dataset.max_objects,
                  backbone='ResNet50V2',
                  anchors=test_anchors,
                  num_classes=class_count,
                  training=True)

    # NOTE(review): input_shape[0] is presumably the input image size and
    # len(test_dataset) the number of batches; confirm against make_loss.
    criterion = losses.make_loss(yolo.num_classes, test_anchors, test_masks,
                                 input_shape[0], len(test_dataset))
    yolo.compile(yolo.get_optimizer('sgd', 1e-4), criterion, run_eagerly=True)

    # Presumably the input images of the first batch; confirm.
    sample_inputs = test_dataset[0][0]
    output = yolo(sample_inputs)
    assert output is not None
Ejemplo n.º 3
0
def test_model_DenseNet121_grid608_64(test_dataset_grid608_64, test_anchors,
                                      test_masks, test_classes):
    """Smoke-test a DenseNet121-backed YoloV3 on the ``grid608_64`` dataset.

    The model is built in training mode and compiled with an SGD optimizer
    (learning rate 1e-4) and the loss from ``losses.make_loss``. It is then
    called directly on ``test_dataset_grid608_64[0][0]``. The only check is
    that the call returns a non-None result.
    """
    dataset = test_dataset_grid608_64
    input_shape = dataset.target_shape
    class_count = len(test_classes)

    yolo = YoloV3(input_shape,
                  dataset.max_objects,
                  backbone='DenseNet121',
                  anchors=test_anchors,
                  num_classes=class_count,
                  training=True)

    # NOTE(review): input_shape[0] is presumably the input image size and
    # len(dataset) the number of batches; confirm against make_loss.
    criterion = losses.make_loss(yolo.num_classes, test_anchors, test_masks,
                                 input_shape[0], len(dataset))
    yolo.compile(yolo.get_optimizer('sgd', 1e-4), criterion, run_eagerly=True)

    # Presumably the input images of the first batch; confirm.
    sample_inputs = dataset[0][0]
    output = yolo(sample_inputs)
    assert output is not None