示例#1
0
    def test_registry(self) -> None:
        """
        Test registering and accessing objects in the Registry.

        Covers:
          * registering a class through the ``register()`` decorator,
          * rejecting a second registration under an already-used name,
          * looking up a registered name (returns the registered object),
          * looking up an unknown name (raises ``KeyError``).
        """
        object_registry = Registry("OBJECT")

        @object_registry.register()
        class Object1:
            pass

        # A second registration of the same name must fail. Only the message
        # is pinned here, not the concrete exception type.
        with self.assertRaises(Exception) as err:
            object_registry.register(Object1)
        self.assertIn(
            "An object named 'Object1' was already registered in 'OBJECT' registry!",
            str(err.exception),
        )

        # Lookup must hand back the very object that was registered.
        self.assertIs(object_registry.get("Object1"), Object1)

        # Unknown names raise KeyError. str() of a KeyError wraps its message
        # in quotes, so a substring check is used instead of equality.
        with self.assertRaises(KeyError) as err:
            object_registry.get("Object2")
        self.assertIn(
            "No object named 'Object2' found in 'OBJECT' registry!",
            str(err.exception),
        )
示例#2
0
    print('in_channels=', in_channels)
    return EfficientNet_.from_pretrained(f'efficientnet-b{name}')


@UNET_ENCODE.register()
def wrn_22():
    """Build a 22-layer Wide ResNet and return its first child as ``nn.Sequential``.

    The network is constructed with ``num_groups=3, N=3, num_classes=10, k=6,
    drop_p=0.2``. Only its first top-level child module is kept, and that
    child's own sub-modules are re-wrapped in a fresh ``nn.Sequential``.
    """
    # NOTE(review): presumably the first child of WideResNet is its feature
    # extractor (an iterable container of layers) — confirm against WideResNet.
    network = WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.2)
    top_level_modules = list(network.children())
    first_module = top_level_modules[0]
    return nn.Sequential(*first_module)


# Register one pre-bound builder per EfficientNet variant b1..b7; for example,
# the key 'efficientnet-b2' maps to efficientnet(name=2). The private
# _do_register is used with an explicit key, presumably because a
# functools.partial has no __name__ for register() to key on.
# NOTE(review): b0 is not registered by this loop — confirm that is intended.
for i in range(1, 8):
    UNET_ENCODE._do_register(
        f'efficientnet-b{i}',
        partial(efficientnet, name=i),
    )

if __name__ == '__main__':
    # Smoke test: fetch a few registry entries, build each encoder and print it.
    # Each pair is (registry key, positional arguments passed to the builder).
    demo_calls = (
        ('wrn_22', ()),
        ('densenet121', ()),
        ('efficientnet', (4,)),
        ('efficientnet-b2', ()),
    )
    for key, args in demo_calls:
        encode = UNET_ENCODE.get(key)
        print(encode(*args))