예제 #1
0
 def __init__(self, transforms):
     """Build the list of pipeline steps from ``transforms``.

     Args:
         transforms: A ``collections.abc.Sequence`` whose items are either
             config dicts, which are turned into objects with
             ``build_from_cfg(item, PIPELINES)``, or callables, which are
             stored unchanged.

     Raises:
         AssertionError: If ``transforms`` is not a
             ``collections.abc.Sequence``. This check is an ``assert`` and is
             skipped when Python runs with ``-O``.
         TypeError: If an item is neither a dict nor callable.
     """
     assert isinstance(transforms, collections.abc.Sequence)
     # Bind the list to self up front so it is filled in place, item by item.
     steps = self.transforms = []
     for item in transforms:
         if isinstance(item, dict):
             step = build_from_cfg(item, PIPELINES)
         elif callable(item):
             step = item
         else:
             raise TypeError('transform must be callable or a dict')
         steps.append(step)
예제 #2
0
def test_build_from_cfg():
    """Exercise ``build_from_cfg`` against a throwaway ``Registry``.

    Happy paths: ``type`` given as a registered name or as the class itself,
    either in ``cfg`` or in ``default_args``, and ``default_args`` filling in
    missing constructor arguments.

    Error paths: non-registry, unregistered type, bad ``default_args``,
    bad ``type`` value, missing ``type`` key, and unexpected constructor
    arguments.
    """
    BACKBONES = Registry("backbone")

    @BACKBONES.register_module()
    class ResNet:
        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:
        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # Each row: (cfg, extra keyword arguments, expected class,
    #            expected depth, expected stages)
    successful_builds = [
        (dict(type="ResNet", depth=50), {}, ResNet, 50, 4),
        (
            dict(type="ResNet", depth=50),
            {"default_args": {"stages": 3}},
            ResNet,
            50,
            3,
        ),
        (dict(type="ResNeXt", depth=50, stages=3), {}, ResNeXt, 50, 3),
        (dict(type=ResNet, depth=50), {}, ResNet, 50, 4),
        # type defined using default_args
        (dict(depth=50), {"default_args": dict(type="ResNet")}, ResNet, 50, 4),
        (dict(depth=50), {"default_args": dict(type=ResNet)}, ResNet, 50, 4),
    ]
    for cfg, extra, expected_cls, depth, stages in successful_builds:
        built = build_from_cfg(cfg, BACKBONES, **extra)
        assert isinstance(built, expected_cls)
        assert built.depth == depth and built.stages == stages

    # Each row: (expected exception, match pattern or None, cfg, registry,
    #            extra keyword arguments)
    failing_builds = [
        # not a registry
        (TypeError, None, dict(type="VGG"), "BACKBONES", {}),
        # non-registered class
        (KeyError, None, dict(type="VGG"), BACKBONES, {}),
        # default_args must be a dict or None
        (
            TypeError,
            None,
            dict(type="ResNet", depth=50),
            BACKBONES,
            {"default_args": 1},
        ),
        # cfg['type'] should be a str or class
        (TypeError, None, dict(type=1000), BACKBONES, {}),
        # cfg should contain the key "type"
        (
            KeyError,
            'must contain the key "type"',
            dict(depth=50, stages=4),
            BACKBONES,
            {},
        ),
        # cfg or default_args should contain the key "type"
        (
            KeyError,
            'must contain the key "type"',
            dict(depth=50),
            BACKBONES,
            {"default_args": dict(stages=4)},
        ),
        # incorrect registry type
        (TypeError, None, dict(type="ResNet", depth=50), "BACKBONES", {}),
        # incorrect default_args type (0 is falsy, unlike the 1 above)
        (
            TypeError,
            None,
            dict(type="ResNet", depth=50),
            BACKBONES,
            {"default_args": 0},
        ),
        # incorrect arguments
        (
            TypeError,
            None,
            dict(type="ResNet", non_existing_arg=50),
            BACKBONES,
            {},
        ),
    ]
    for exc_type, pattern, cfg, registry, extra in failing_builds:
        with pytest.raises(exc_type, match=pattern):
            build_from_cfg(cfg, registry, **extra)
예제 #3
0
def build_data_loader(cfg, default_args=None):
    """Build a data loader from ``cfg`` using the ``DATA_LOADER`` registry.

    ``cfg`` and ``default_args`` are forwarded unchanged to
    ``build_from_cfg``; whatever that call returns is returned.
    """
    data_loader = build_from_cfg(cfg, DATA_LOADER, default_args)
    return data_loader
예제 #4
0
def build(cfg, registry, default_args=None):
    """Build the object described by ``cfg`` from ``registry``.

    This is a thin wrapper: all three arguments are passed positionally to
    ``build_from_cfg`` and its result is returned as-is.
    """
    built = build_from_cfg(cfg, registry, default_args)
    return built
예제 #5
0
def build_dataset(cfg, default_args=None):
    """Build a dataset from ``cfg`` using the ``DATASETS`` registry.

    ``cfg`` and ``default_args`` are forwarded unchanged to
    ``build_from_cfg``; whatever that call returns is returned.
    """
    return build_from_cfg(cfg, DATASETS, default_args)
예제 #6
0
def __build_backbone_component(backbone_component_cfg):
    """Build a backbone component from its config via ``BACKBONE_COMPONENT``."""
    component = build_from_cfg(backbone_component_cfg, BACKBONE_COMPONENT)
    return component
예제 #7
0
def __build_tracker_component(tracker_component_cfg):
    """Build a tracker component from its config via the ``TRACKER`` registry."""
    component = build_from_cfg(tracker_component_cfg, TRACKER)
    return component
예제 #8
0
def __build_detector_component(detector_component_cfg):
    """Build a detector component from its config via the ``DETECTOR`` registry."""
    component = build_from_cfg(detector_component_cfg, DETECTOR)
    return component
예제 #9
0
def __build_head_component(head_component_cfg):
    """Build a head component from its config via the ``HEAD`` registry."""
    component = build_from_cfg(head_component_cfg, HEAD)
    return component