예제 #1
0
def test_torch_image_classification_custom_net():
    """Run ``ImageClassification`` end to end with a custom timm network.

    Creates a timm ``resnet18``, replaces its ``fc`` attribute with a
    ``nn.Linear(512, 4)`` layer, fits a single one-epoch trial with batch
    size 8, asserts that the fit summary reports a positive ``valid_acc``,
    and finally calls ``predict`` on the test data.
    """
    from gluoncv.auto.tasks import ImageClassification
    from timm import create_model
    import torch.nn as nn

    backbone = create_model('resnet18')
    # Replace the final fully connected layer with a 4-output linear layer.
    backbone.fc = nn.Linear(512, 4)

    config = {
        'num_trials': 1,
        'epochs': 1,
        'custom_net': backbone,
        'batch_size': 8,
    }
    trainer = ImageClassification(config)
    predictor = trainer.fit(IMAGE_CLASS_DATASET)

    summary = trainer.fit_summary()
    assert summary.get('valid_acc', 0) > 0

    # The prediction output is not inspected; the call only has to complete.
    predictions = predictor.predict(IMAGE_CLASS_TEST)
예제 #2
0
def test_image_classification():
    """Run ``ImageClassification`` end to end with the ``resnet18_v1`` model.

    Fits a single one-epoch trial with batch size 8, asserts that the fit
    summary reports a positive ``valid_acc``, and then calls ``predict`` on
    the test data.
    """
    from gluoncv.auto.tasks import ImageClassification

    config = {
        'model': 'resnet18_v1',
        'num_trials': 1,
        'epochs': 1,
        'batch_size': 8,
    }
    trainer = ImageClassification(config)
    predictor = trainer.fit(IMAGE_CLASS_DATASET)

    summary = trainer.fit_summary()
    assert summary.get('valid_acc', 0) > 0

    # The prediction output is not inspected; the call only has to complete.
    predictions = predictor.predict(IMAGE_CLASS_TEST)
예제 #3
0
def test_image_classification_custom_net():
    """Run ``ImageClassification`` end to end with a custom model-zoo network.

    Obtains ``resnet18_v1`` from ``gluoncv.model_zoo.get_model``, passes it
    as ``custom_net`` for a single one-epoch trial, asserts that the fit
    summary reports a positive ``valid_acc``, and then calls ``predict`` on
    the test data.
    """
    from gluoncv.auto.tasks import ImageClassification
    from gluoncv.model_zoo import get_model

    backbone = get_model('resnet18_v1')

    config = {
        'num_trials': 1,
        'epochs': 1,
        'custom_net': backbone,
    }
    trainer = ImageClassification(config)
    predictor = trainer.fit(IMAGE_CLASS_DATASET)

    summary = trainer.fit_summary()
    assert summary.get('valid_acc', 0) > 0

    # The prediction output is not inspected; the call only has to complete.
    predictions = predictor.predict(IMAGE_CLASS_TEST)
예제 #4
0
def test_image_classification():
    """Run ``ImageClassification`` end to end with only ``num_trials`` set.

    Fits a single trial, asserts that the fit summary reports a positive
    ``valid_acc``, and then calls ``predict`` on the test data.
    """
    from gluoncv.auto.tasks import ImageClassification

    config = {'num_trials': 1}
    trainer = ImageClassification(config)
    predictor = trainer.fit(IMAGE_CLASS_DATASET)

    summary = trainer.fit_summary()
    assert summary.get('valid_acc', 0) > 0

    # The prediction output is not inspected; the call only has to complete.
    predictions = predictor.predict(IMAGE_CLASS_TEST)