Beispiel #1
0
def test_torchvision_registry_models(key: str, pretrained: Union[bool, str],
                                     constructor: Callable):
    """Compare a registry-created model against direct ``constructor`` builds.

    The model returned by ``ModelRegistry.create(key, pretrained)`` is always
    compared with ``same=False`` against ``constructor(pretrained=False)``.
    Only when ``pretrained`` is exactly ``True`` is it additionally compared
    with ``same=True`` against a model built by ``constructor`` with the same
    ``pretrained`` value.

    :param key: registry key handed to ``ModelRegistry.create``
    :param pretrained: pretrained setting forwarded to the registry
    :param constructor: callable accepting a ``pretrained`` keyword and
        returning the reference model
    """
    registry_model = ModelRegistry.create(key, pretrained)
    untrained_reference = constructor(pretrained=False)
    compare_model(registry_model, untrained_reference, same=False)

    if pretrained is not True:
        # only the literal ``True`` setting is matched against the constructor
        return

    # check torchvision weights are properly loaded
    pretrained_reference = constructor(pretrained=pretrained)
    compare_model(registry_model, pretrained_reference, same=True)
Beispiel #2
0
def test_ssd_resnsets(
    key: str,
    pretrained: Union[bool, str],
    pretrained_backbone: Union[bool, str],
    test_input: bool,
    match_const: Callable,
):
    """Check registry creation, weight loading and output shapes for an SSD model.

    Steps, in order:

    1. Build one model through ``ModelRegistry.create(key, pretrained)`` and
       one through ``match_const(pretrained_backbone=pretrained_backbone)``.
    2. If ``pretrained`` is truthy, the two are compared with ``same=False``
       and a second registry build is compared with ``same=True``.
    3. If ``pretrained_backbone`` is truthy but not the literal ``True``, the
       ``feature_extractor`` attributes are compared: registry vs. constructor
       with ``same=False``, constructor vs. a registry build created with
       ``pretrained_backbone`` with ``same=True``.
    4. If ``test_input`` is set, a random batch of size 1 is run through the
       model in eval mode and the ``(boxes, scores)`` output is shape-checked.

    :param key: registry key of the model under test
    :param pretrained: pretrained setting forwarded to the registry
    :param pretrained_backbone: backbone setting forwarded to ``match_const``
        and, for step 3, to the registry
    :param test_input: whether to run the forward-pass checks
    :param match_const: callable accepting a ``pretrained_backbone`` keyword
        and returning the reference model
    """
    registry_model = ModelRegistry.create(key, pretrained)
    const_model = match_const(pretrained_backbone=pretrained_backbone)

    if pretrained:
        compare_model(registry_model, const_model, same=False)
        recreated_model = ModelRegistry.create(key, pretrained)
        compare_model(registry_model, recreated_model, same=True)

    # truthy but not literally ``True``
    custom_backbone = bool(pretrained_backbone) and pretrained_backbone is not True
    if custom_backbone:
        registry_extractor = registry_model.feature_extractor
        const_extractor = const_model.feature_extractor
        compare_model(registry_extractor, const_extractor, same=False)
        backbone_model = ModelRegistry.create(
            key, pretrained_backbone=pretrained_backbone
        )
        compare_model(const_extractor, backbone_model.feature_extractor, same=True)

    if not test_input:
        return

    batch = torch.randn(1, *ModelRegistry.input_shape(key))
    registry_model.eval()
    boxes, scores = registry_model(batch)

    for prediction in (boxes, scores):
        assert isinstance(prediction, torch.Tensor)
    for prediction in (boxes, scores):
        assert prediction.dim() == 3

    assert boxes.shape[0] == 1
    assert boxes.shape[1] == 4
    assert scores.shape[0] == 1
    # check same num default boxes
    assert boxes.shape[2] == scores.shape[2]
Beispiel #3
0
def test_yolo_v3(
    key: str,
    pretrained: Union[bool, str],
    pretrained_backbone: Union[bool, str],
    test_input: bool,
    match_const: Callable,
):
    """Check registry creation, weight loading and output shapes for YOLOv3.

    A model is built through ``ModelRegistry.create(key, pretrained)`` and a
    reference through ``match_const(pretrained_backbone=pretrained_backbone)``.
    When ``pretrained`` is truthy the two are compared with ``same=False`` and
    a second registry build is compared with ``same=True``. When
    ``pretrained_backbone`` is truthy but not the literal ``True``, the
    ``backbone`` attributes are compared: registry vs. constructor with
    ``same=False``, constructor vs. a registry build created with
    ``pretrained_backbone`` with ``same=True``. When ``test_input`` is set,
    a random batch of size 1 is run in eval mode and the output must be a
    list of 5-D tensors whose last dimension is 85.

    :param key: registry key of the model under test
    :param pretrained: pretrained setting forwarded to the registry
    :param pretrained_backbone: backbone setting forwarded to ``match_const``
        and, for the backbone check, to the registry
    :param test_input: whether to run the forward-pass checks
    :param match_const: callable accepting a ``pretrained_backbone`` keyword
        and returning the reference model
    """
    registry_model = ModelRegistry.create(key, pretrained)
    const_model = match_const(pretrained_backbone=pretrained_backbone)

    if pretrained:
        compare_model(registry_model, const_model, same=False)
        recreated_model = ModelRegistry.create(key, pretrained)
        compare_model(registry_model, recreated_model, same=True)

    # truthy but not literally ``True``
    custom_backbone = bool(pretrained_backbone) and pretrained_backbone is not True
    if custom_backbone:
        registry_backbone = registry_model.backbone
        const_backbone = const_model.backbone
        compare_model(registry_backbone, const_backbone, same=False)
        backbone_model = ModelRegistry.create(
            key, pretrained_backbone=pretrained_backbone
        )
        compare_model(const_backbone, backbone_model.backbone, same=True)

    if not test_input:
        return

    batch = torch.randn(1, *ModelRegistry.input_shape(key))
    registry_model.eval()
    predictions = registry_model(batch)

    assert isinstance(predictions, list)
    for prediction in predictions:
        assert isinstance(prediction, torch.Tensor)
        assert prediction.dim() == 5
        assert prediction.shape[-1] == 85
Beispiel #4
0
def test_mnist(key: str, pretrained: Union[bool, str], test_input: bool):
    """Check registry creation, weight loading and output shapes for MNIST.

    A model is built through ``ModelRegistry.create(key, pretrained)`` and a
    reference through ``mnist_net()``. When ``pretrained`` is truthy the two
    are compared with ``same=False`` and a second registry build is compared
    with ``same=True``. When ``test_input`` is set, a random batch of size 1
    is run through the model in eval mode; the output must be a tuple whose
    members each have shape ``(1, 10, ...)``.

    :param key: registry key of the model under test
    :param pretrained: pretrained setting forwarded to the registry
    :param test_input: whether to run the forward-pass checks
    """
    model = ModelRegistry.create(key, pretrained)
    diff_model = mnist_net()

    if pretrained:
        compare_model(model, diff_model, same=False)
        match_model = ModelRegistry.create(key, pretrained)
        compare_model(model, match_model, same=True)

    if test_input:
        input_shape = ModelRegistry.input_shape(key)
        batch = torch.randn(1, *input_shape)
        # Switch to eval mode before the forward pass, as the other registry
        # tests do. With a batch of one, training mode would let any
        # batch-norm or dropout layers (if the network has them) make the
        # forward pass non-deterministic or mutate running statistics.
        model.eval()
        out = model(batch)
        assert isinstance(out, tuple)
        for tens in out:
            assert tens.shape[0] == 1
            assert tens.shape[1] == 10
Beispiel #5
0
def test_efficientnet(
    key: str,
    pretrained: Union[bool, str],
    test_input: bool,
    match_const: Callable,
    model_args: dict,
):
    """Check registry creation, weight loading and output shapes for EfficientNet.

    A model is built through ``ModelRegistry.create(key, pretrained,
    **model_args)`` and a reference through ``match_const(**model_args)``.
    When ``pretrained`` is truthy the two are compared with ``same=False`` and
    a second registry build is compared with ``same=True``. When
    ``test_input`` is set, a random batch of size 1 is run through the model
    in eval mode; the output must be a tuple whose members each have shape
    ``(1, 1000, ...)``.

    :param key: registry key of the model under test
    :param pretrained: pretrained setting forwarded to the registry
    :param test_input: whether to run the forward-pass checks
    :param match_const: callable building the reference model from
        ``model_args``
    :param model_args: extra keyword arguments forwarded to both builders
    """
    registry_model = ModelRegistry.create(key, pretrained, **model_args)
    const_model = match_const(**model_args)

    if pretrained:
        compare_model(registry_model, const_model, same=False)
        recreated_model = ModelRegistry.create(key, pretrained, **model_args)
        compare_model(registry_model, recreated_model, same=True)

    if not test_input:
        return

    batch = torch.randn(1, *ModelRegistry.input_shape(key))
    registry_model.eval()
    outputs = registry_model(batch)

    assert isinstance(outputs, tuple)
    for tensor in outputs:
        assert tensor.shape[0] == 1
        assert tensor.shape[1] == 1000