コード例 #1
0
def test_build_model_from_cfg():
    """Check a registry whose ``build_func`` is ``build_model_from_cfg``.

    A single config dict builds one module; a list of config dicts builds an
    ``nn.Sequential`` holding the modules in list order. A child registry
    inherits its parent's ``build_func`` unless one is passed explicitly.
    """
    backbones = mmcv.Registry('backbone', build_func=build_model_from_cfg)

    @backbones.register_module()
    class ResNet(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    @backbones.register_module()
    class ResNeXt(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    # one dict -> one module; `stages` falls back to the constructor default
    resnet = backbones.build(dict(type='ResNet', depth=50))
    assert isinstance(resnet, ResNet)
    assert resnet.depth == 50 and resnet.stages == 4

    resnext = backbones.build(dict(type='ResNeXt', depth=50, stages=3))
    assert isinstance(resnext, ResNeXt)
    assert resnext.depth == 50 and resnext.stages == 3

    # a list of dicts -> nn.Sequential of the built modules, in order
    sequential = backbones.build([
        dict(type='ResNet', depth=50),
        dict(type='ResNeXt', depth=50, stages=3),
    ])
    assert isinstance(sequential, nn.Sequential)
    first, second = sequential[0], sequential[1]
    assert isinstance(first, ResNet)
    assert first.depth == 50 and first.stages == 4
    assert isinstance(second, ResNeXt)
    assert second.depth == 50 and second.stages == 3

    # MODELS comes from module scope (not shown here); a child registry
    # created without `build_func` is expected to inherit the parent's one
    child = mmcv.Registry('models', parent=MODELS, scope='new')
    assert child.build_func is build_model_from_cfg

    # an explicitly supplied `build_func` is kept as-is
    def pseudo_build(cfg):
        return cfg

    child = mmcv.Registry('models', parent=MODELS, build_func=pseudo_build)
    assert child.build_func is pseudo_build
コード例 #2
0
ファイル: test_registry.py プロジェクト: zhouzaida/mmcv
def test_multi_scope_registry():
    """Check name resolution across a hierarchy of scoped registries.

    Hierarchy built here::

        dogs           (scope inferred from this module: 'test_registry')
        `-- hounds     (scope 'hound')
            |-- little_hounds  (scope 'little_hound')
            `-- mid_hounds     (scope 'mid_hound')

    Lookups are checked with plain names, with child-scope prefixes
    (``'hound.BloodHound'`` from ``dogs``), and with prefixes that name a
    sibling branch (``'hound.mid_hound.Beagle'`` from ``little_hounds``).
    """
    dogs = mmcv.Registry('dogs')
    assert dogs.name == 'dogs'
    assert dogs.scope == 'test_registry'
    assert dogs.module_dict == {}
    assert len(dogs) == 0

    @dogs.register_module()
    class GoldenRetriever:
        pass

    assert len(dogs) == 1
    assert dogs.get('GoldenRetriever') is GoldenRetriever

    hounds = mmcv.Registry('dogs', parent=dogs, scope='hound')

    @hounds.register_module()
    class BloodHound:
        pass

    assert len(hounds) == 1
    for registry, key in ((hounds, 'BloodHound'), (dogs, 'hound.BloodHound'),
                          (hounds, 'hound.BloodHound')):
        assert registry.get(key) is BloodHound

    little_hounds = mmcv.Registry('dogs', parent=hounds, scope='little_hound')

    @little_hounds.register_module()
    class Dachshund:
        pass

    assert len(little_hounds) == 1
    assert little_hounds.get('Dachshund') is Dachshund
    assert little_hounds.get('hound.BloodHound') is BloodHound
    assert hounds.get('little_hound.Dachshund') is Dachshund
    assert dogs.get('hound.little_hound.Dachshund') is Dachshund

    mid_hounds = mmcv.Registry('dogs', parent=hounds, scope='mid_hound')

    @mid_hounds.register_module()
    class Beagle:
        pass

    beagle_lookups = [
        (mid_hounds, 'Beagle'),
        (hounds, 'mid_hound.Beagle'),
        (dogs, 'hound.mid_hound.Beagle'),
        (little_hounds, 'hound.mid_hound.Beagle'),
    ]
    for registry, key in beagle_lookups:
        assert registry.get(key) is Beagle
    assert mid_hounds.get('hound.BloodHound') is BloodHound
    # Dachshund was registered under little_hound, not directly under hound
    assert mid_hounds.get('hound.Dachshund') is None
コード例 #3
0
def test_build_from_cfg():
    """Check ``mmcv.build_from_cfg`` with a registry filled through the old,
    parenthesis-free ``@register_module`` decorator.

    Covers building by registered name and by class, filling missing
    constructor arguments from ``default_args``, and the errors raised for an
    unknown type, a bad ``type`` value, a config without ``"type"``, a
    ``registry`` argument that is not a ``Registry`` and a non-dict
    ``default_args``.
    """
    BACKBONES = mmcv.Registry('backbone')

    @BACKBONES.register_module
    class ResNet:
        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module
    class ResNeXt:
        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # constructor arguments missing from cfg are taken from default_args
    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # the class object itself is accepted as `type`
    cfg = dict(type=ResNet, depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # non-registered class
    with pytest.raises(KeyError):
        mmcv.build_from_cfg(dict(type='VGG'), BACKBONES)

    # cfg['type'] should be a str or class
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(dict(type=1000), BACKBONES)

    # cfg should contain the key "type"
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(dict(depth=50, stages=4), BACKBONES)

    # The two checks below pass an otherwise valid cfg. Reusing the
    # "type"-less cfg from the check above would raise TypeError by itself,
    # so these checks would pass regardless of the argument under test.

    # incorrect registry type
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(dict(type='ResNet', depth=50), 'BACKBONES')

    # incorrect default_args type
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(
            dict(type='ResNet', depth=50), BACKBONES, default_args=0)
コード例 #4
0
def test_registry():
    """Exercise registration, lookup, forced overwriting and repr of a
    Registry using the old, parenthesis-free ``register_module`` API.
    """
    registry_name = 'cat'
    cats = mmcv.Registry(registry_name)
    assert cats.name == registry_name
    assert cats.module_dict == {}
    assert len(cats) == 0

    @cats.register_module
    class BritishShorthair:
        pass

    assert len(cats) == 1
    assert cats.get('BritishShorthair') is BritishShorthair

    class Munchkin:
        pass

    # registration by a direct call rather than as a decorator
    cats.register_module(Munchkin)
    assert len(cats) == 2
    assert cats.get('Munchkin') is Munchkin
    assert 'Munchkin' in cats

    # re-registering an existing name fails unless force=True is given
    with pytest.raises(KeyError):
        cats.register_module(Munchkin)

    cats.register_module(Munchkin, force=True)
    assert len(cats) == 2

    with pytest.raises(KeyError):

        @cats.register_module
        class BritishShorthair:
            pass

    @cats.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(cats) == 2

    # an unknown name yields None from get() and is not a member
    assert cats.get('PersianCat') is None
    assert 'PersianCat' not in cats

    # dict key order is not preserved in Python 3.5, so accept either order
    accepted_reprs = {
        "Registry(name=cat, items=['BritishShorthair', 'Munchkin'])",
        "Registry(name=cat, items=['Munchkin', 'BritishShorthair'])",
    }
    assert repr(cats) in accepted_reprs

    # only classes may be registered
    with pytest.raises(TypeError):
        cats.register_module(0)
コード例 #5
0
ファイル: test_registry.py プロジェクト: zhouzaida/mmcv
def test_registry():
    """Exercise the current ``Registry.register_module`` API and the
    backward-compatible old-style calls.

    Covers decorator and direct-call registration, ``force`` overwriting,
    registering one class under one or several explicit names, the exact
    ``repr`` of the registry, the ``TypeError`` cases for a bad name or
    module, and the ``UserWarning`` emitted by the deprecated call styles.
    """
    CATS = mmcv.Registry('cat')
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    class Munchkin:
        pass

    # register by passing the class directly instead of decorating it
    CATS.register_module(Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # registering an already-used name fails unless force=True
    with pytest.raises(KeyError):
        CATS.register_module(Munchkin)

    CATS.register_module(Munchkin, force=True)
    # overwriting an existing name does not add a new entry
    assert len(CATS) == 2

    # force=False
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(CATS) == 2

    # unknown names: get() returns None and membership is False
    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # a list of names registers the same class under each alias
    @CATS.register_module(name=['Siamese', 'Siamese2'])
    class SiameseCat:
        pass

    assert CATS.get('Siamese').__name__ == 'SiameseCat'
    assert CATS.get('Siamese2').__name__ == 'SiameseCat'

    class SphynxCat:
        pass

    # non-decorator form with explicit name(s) and module
    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx2') is SphynxCat

    # Expected repr after all registrations above. '<locals>' appears in the
    # class reprs because the classes are defined inside this function, and
    # the leading 'test_registry.' module prefix assumes this file is
    # imported as `test_registry`. Adding, removing or reordering the
    # registrations above requires updating this string.
    repr_str = 'Registry(name=cat, items={'
    repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
                 "<locals>.BritishShorthair'>, ")
    repr_str += ("'Munchkin': <class 'test_registry.test_registry."
                 "<locals>.Munchkin'>, ")
    repr_str += ("'Siamese': <class 'test_registry.test_registry."
                 "<locals>.SiameseCat'>, ")
    repr_str += ("'Siamese2': <class 'test_registry.test_registry."
                 "<locals>.SiameseCat'>, ")
    repr_str += ("'Sphynx': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>")
    repr_str += '})'
    assert repr(CATS) == repr_str

    # name type: an integer name is rejected
    with pytest.raises(TypeError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # the registered module should be a class
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # can only decorate a class
    with pytest.raises(TypeError):

        @CATS.register_module()
        def some_method():
            pass

    # begin: test old APIs
    # Each deprecated call style below is expected to emit a UserWarning
    # while still registering the class under its own __name__.
    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):

        @CATS.register_module
        class NewCat:
            pass

        assert CATS.get('NewCat').__name__ == 'NewCat'

    with pytest.warns(UserWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

        assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

        assert CATS.get('NewCat2').__name__ == 'NewCat2'
コード例 #6
0
ファイル: test_registry.py プロジェクト: zhouzaida/mmcv
def test_build_from_cfg():
    """Check ``mmcv.build_from_cfg`` on valid and invalid inputs.

    Valid configs name the type either as a registered string or as the
    class itself, given in ``cfg`` or in ``default_args``; constructor
    arguments missing from ``cfg`` are taken from ``default_args``. Invalid
    inputs must raise the ``TypeError``/``KeyError`` asserted below.
    """
    backbones = mmcv.Registry('backbone')

    @backbones.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @backbones.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # (cfg, default_args, expected class, expected stages); depth is always 50
    valid_cases = [
        (dict(type='ResNet', depth=50), None, ResNet, 4),
        (dict(type='ResNet', depth=50), {'stages': 3}, ResNet, 3),
        (dict(type='ResNeXt', depth=50, stages=3), None, ResNeXt, 3),
        # the class object itself is accepted as `type`
        (dict(type=ResNet, depth=50), None, ResNet, 4),
        # `type` may be supplied through default_args instead of cfg
        (dict(depth=50), dict(type='ResNet'), ResNet, 4),
        (dict(depth=50), dict(type=ResNet), ResNet, 4),
    ]
    for cfg, default_args, expected_cls, expected_stages in valid_cases:
        built = mmcv.build_from_cfg(cfg, backbones, default_args=default_args)
        assert isinstance(built, expected_cls)
        assert built.depth == 50 and built.stages == expected_stages

    # a plain string is not a registry
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(dict(type='VGG'), 'BACKBONES')

    # 'VGG' was never registered
    with pytest.raises(KeyError):
        mmcv.build_from_cfg(dict(type='VGG'), backbones)

    # default_args must be a dict or None
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(
            dict(type='ResNet', depth=50), backbones, default_args=1)

    # cfg['type'] must be a str or a class
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(dict(type=1000), backbones)

    # "type" absent from cfg, and from default_args when those are given
    with pytest.raises(KeyError, match='must contain the key "type"'):
        mmcv.build_from_cfg(dict(depth=50, stages=4), backbones)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        mmcv.build_from_cfg(
            dict(depth=50), backbones, default_args=dict(stages=4))

    # incorrect registry type
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(dict(type='ResNet', depth=50), 'BACKBONES')

    # incorrect default_args type
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(
            dict(type='ResNet', depth=50), backbones, default_args=0)

    # a keyword the constructor does not accept
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(
            dict(type='ResNet', non_existing_arg=50), backbones)