Ejemplo n.º 1
0
    def test_registry(self) -> None:
        """
        Test registering and accessing objects in the Registry.

        Checks that:
        * registering an already-registered name fails with a descriptive message,
        * a registered class can be retrieved by its name,
        * looking up an unknown name raises ``KeyError`` with a descriptive message.
        """
        OBJECT_REGISTRY = Registry("OBJECT")

        @OBJECT_REGISTRY.register()
        class Object1:
            pass

        # A second registration under the same name must be rejected.
        with self.assertRaises(Exception) as err:
            OBJECT_REGISTRY.register(Object1)
        self.assertIn(
            "An object named 'Object1' was already registered in 'OBJECT' registry!",
            str(err.exception),
        )

        # The registry must hand back the very same class object.
        self.assertIs(OBJECT_REGISTRY.get("Object1"), Object1)

        with self.assertRaises(KeyError) as err:
            OBJECT_REGISTRY.get("Object2")
        self.assertIn(
            "No object named 'Object2' found in 'OBJECT' registry!",
            str(err.exception),
        )
Ejemplo n.º 2
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

DISCRIMINATOR_REGISTRY = Registry("DISCRIMINATOR_REGISTRY")  # noqa F401 isort:skip
DISCRIMINATOR_REGISTRY.__doc__ = """
Registry for discriminators.

:func:`build_discriminator` looks up ``cfg.name`` here and calls the entry
as ``obj(cfg=cfg, **kwargs)``.
"""

GENERATOR_REGISTRY = Registry("GENERATOR_REGISTRY")  # noqa F401 isort:skip
GENERATOR_REGISTRY.__doc__ = """
Registry for generators.

:func:`build_generator` looks up ``cfg.name`` here and calls the entry
as ``obj(cfg=cfg, **kwargs)``.
"""


def build_discriminator(cfg, **kwargs):
    """
    Build the discriminator named by ``cfg.name``.

    The object registered under ``cfg.name`` in ``DISCRIMINATOR_REGISTRY`` is
    called as ``obj(cfg=cfg, **kwargs)``; extra keyword arguments are
    forwarded unchanged.

    Returns:
        whatever the registered object returns.
    """
    name = cfg.name
    return DISCRIMINATOR_REGISTRY.get(name)(cfg=cfg, **kwargs)


def build_generator(cfg, **kwargs):
    """
    Build the generator named by ``cfg.name``.

    The object registered under ``cfg.name`` in ``GENERATOR_REGISTRY`` is
    called as ``obj(cfg=cfg, **kwargs)``; extra keyword arguments are
    forwarded unchanged.

    Returns:
        whatever the registered object returns.
    """
    name = cfg.name
    return GENERATOR_REGISTRY.get(name)(cfg=cfg, **kwargs)
Ejemplo n.º 3
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Named registry of loss classes; entries are registered elsewhere.
LOSS_REGISTRY = Registry(name="LOSS")
LOSS_REGISTRY.__doc__ = """
    Registry for loss function classes
"""
Ejemplo n.º 4
0
from fvcore.common.registry import Registry

RESNEST_DATASETS_REGISTRY = Registry(name='RESNEST_DATASETS')


def get_dataset(dataset_name):
    """Return the object registered under ``dataset_name`` (it is not called)."""
    dataset_entry = RESNEST_DATASETS_REGISTRY.get(dataset_name)
    return dataset_entry
Ejemplo n.º 5
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Model construction functions."""

import torch
from fvcore.common.registry import Registry

MODEL_REGISTRY = Registry(name="MODEL")
# The text below is exposed as ``MODEL_REGISTRY.__doc__``.
MODEL_REGISTRY.__doc__ = """
Registry for video model.

The registered object will be called with `obj(cfg)`.
The call should return a `torch.nn.Module` object.
"""


def build_model(cfg, gpu_id=None):
    """
    Builds the video model.
    Args:
        cfg (configs): configs that contains the hyper-parameters to build the
        backbone. Details can be seen in slowfast/config/defaults.py.
        gpu_id (Optional[int]): specify the gpu index to build model.
    """
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        assert (
Ejemplo n.º 6
0
        if BASE_KEY in cfg:
            cfg[BASE_KEY] = reroute_config_path(cfg[BASE_KEY])
        return cfg

    def mock_unsafe_load(f):
        cfg = _unsafe_load(f)
        if BASE_KEY in cfg:
            cfg[BASE_KEY] = reroute_config_path(cfg[BASE_KEY])
        return cfg

    with mock.patch("yaml.safe_load", side_effect=mock_safe_load):
        with mock.patch("yaml.unsafe_load", side_effect=mock_unsafe_load):
            yield


CONFIG_SCALING_METHOD_REGISTRY = Registry(name="CONFIG_SCALING_METHOD")  # methods named in cfg.SOLVER.AUTO_SCALING_METHODS


def auto_scale_world_size(cfg, new_world_size):
    """
    Usually the config file is written for a specific number of devices, this method
    scales the config (in-place!) according to the actual world size using the
    pre-registered scaling methods specified as cfg.SOLVER.AUTO_SCALING_METHODS.

    Note for registering scaling methods:
        - The method will only be called when scaling is needed. It won't be called
            if SOLVER.REFERENCE_WORLD_SIZE is 0 or equal to target world size. Thus
            cfg.SOLVER.REFERENCE_WORLD_SIZE will always be positive.
        - The method updates cfg in-place, no return is required.
        - No need for changing SOLVER.REFERENCE_WORLD_SIZE.
Ejemplo n.º 7
0
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
from fastai.vision import *
from fastai.vision.models import WideResNet
from fvcore.common.registry import Registry

UNET_ENCODE = Registry(name="UNET_ENCODE")  # encoder factories, keyed by function name


@UNET_ENCODE.register()
def resnet18(pretrained=True):
    """
    Return a ResNet-18 encoder as an ``nn.Sequential``.

    The last two children of the network are dropped (for torchvision's ResNet
    these are the global average pool and the fully-connected classifier), so
    the output is the final convolutional feature map.

    Args:
        pretrained (bool): whether to load pretrained weights. Defaults to
            ``True``, matching the previous hard-coded behaviour; pass
            ``False`` to build the architecture without downloading weights.
    """
    backbone = models.resnet18(pretrained=pretrained)
    return nn.Sequential(*list(backbone.children())[:-2])


def _densenet_encoder(arch, pretrained=True):
    """
    Return the feature extractor of the DenseNet built by ``arch`` as an ``nn.Sequential``.

    Only the layers of the network's first child are kept (for torchvision's
    DenseNet that child is the ``features`` container), so the classifier head
    is excluded.

    Args:
        arch: a DenseNet constructor accepting ``pretrained=``.
        pretrained (bool): whether to load pretrained weights.
    """
    backbone = arch(pretrained=pretrained)
    return nn.Sequential(*list(backbone.children())[0])


@UNET_ENCODE.register()
def densenet121(pretrained=True):
    """Return a DenseNet-121 feature encoder (pretrained weights by default)."""
    return _densenet_encoder(models.densenet121, pretrained)


@UNET_ENCODE.register()
def densenet169(pretrained=True):
    """Return a DenseNet-169 feature encoder (pretrained weights by default)."""
    return _densenet_encoder(models.densenet169, pretrained)


@UNET_ENCODE.register()
def densenet201(pretrained=True):
    """Return a DenseNet-201 feature encoder (pretrained weights by default)."""
    return _densenet_encoder(models.densenet201, pretrained)
Ejemplo n.º 8
0
from fvcore.common.registry import Registry
import torch.nn as nn

BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images.
The registered object must be a callable that accepts one argument:
1. A :class:`detectron2.config.CfgNode`
It must return an instance of :class:`Backbone`.
"""


def build_backbone(cfg):
    """
    Construct the backbone named by `cfg.MODEL.BACKBONE.NAME`.

    The registered callable receives `cfg` as its only argument.

    Returns:
        an instance of :class:`Backbone`
    """
    backbone_factory = BACKBONE_REGISTRY.get(cfg.MODEL.BACKBONE.NAME)
    return backbone_factory(cfg)


def get_norm(cfg, out_channels, momentum=0.1):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
Ejemplo n.º 9
0
import contextlib
import copy
import logging
from typing import List

import mock
import yaml
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from detectron2.config import CfgNode as _CfgNode
from fvcore.common.registry import Registry

from .utils import reroute_config_path

logger = logging.getLogger(__name__)

# Registry for custom config-parsing entries; members are registered outside this view.
CONFIG_CUSTOM_PARSE_REGISTRY = Registry(name="CONFIG_CUSTOM_PARSE")
DEFAULTS_GENERATOR_KEY = "_DEFAULTS_"  # special config key; consumers are outside this view


def _opts_to_dict(opts: List[str]):
    ret = {}
    for full_key, v in zip(opts[0::2], opts[1::2]):
        keys = full_key.split(".")
        cur = ret
        for key in keys[:-1]:
            if key not in cur:
                cur[key] = {}
            cur = cur[key]
        cur[keys[-1]] = v
    return ret
Ejemplo n.º 10
0
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: [email protected]
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
from torchvision.transforms import *
from .transforms import *
from fvcore.common.registry import Registry

RESNEST_TRANSFORMS_REGISTRY = Registry(name='RESNEST_TRANSFORMS')


def get_transform(dataset_name):
    """Return the transform factory registered under the lower-cased ``dataset_name``."""
    registry_key = dataset_name.lower()
    return RESNEST_TRANSFORMS_REGISTRY.get(registry_key)


@RESNEST_TRANSFORMS_REGISTRY.register()
def imagenet(base_size=None, crop_size=224, rand_aug=False):
    normalize = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
    base_size = base_size if base_size is not None else int(1.0 * crop_size /
                                                            0.875)
    train_transforms = []
    val_transforms = []
    if rand_aug:
        from .autoaug import RandAugment
        train_transforms.append(RandAugment(2, 12))
Ejemplo n.º 11
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Named registry of hook classes; entries are registered elsewhere.
HOOK_REGISTRY = Registry(name="HOOK")
HOOK_REGISTRY.__doc__ = """
    Registry for hook functional classes
"""
Ejemplo n.º 12
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

GAN_METRIC_REGISTRY = Registry("GAN_METRIC_REGISTRY")  # noqa F401 isort:skip
GAN_METRIC_REGISTRY.__doc__ = """
Registry for GAN evaluation metrics.

Entries are looked up by name (``cfg.name`` or each of ``cfg.GAN_metric.names``)
and called as ``obj(cfg=cfg, **kwargs)``.
"""


def build_GAN_metric_dict(cfg, **kwargs):
    """
    Build every GAN metric listed in ``cfg.GAN_metric.names``.

    Each name is looked up in ``GAN_METRIC_REGISTRY`` and the registered
    object is called as ``obj(cfg=cfg, **kwargs)``.

    Returns:
        dict: metric name -> built metric. Empty when ``cfg.GAN_metric`` has
        no ``names`` entry (or it is ``None``).
    """
    names = cfg.GAN_metric.get('names')
    if names is None:
        return {}
    return {name: GAN_METRIC_REGISTRY.get(name)(cfg=cfg, **kwargs) for name in names}


def build_GAN_metric(cfg, **kwargs):
    """
    Build the single GAN metric named by ``cfg.name``.

    The object registered under ``cfg.name`` in ``GAN_METRIC_REGISTRY`` is
    called as ``obj(cfg=cfg, **kwargs)`` and its result is returned.
    """
    return GAN_METRIC_REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
Ejemplo n.º 13
0
from fvcore.common.registry import Registry

RESNEST_MODELS_REGISTRY = Registry(name='RESNEST_MODELS')


def get_model(model_name):
    """Return the object registered under ``model_name`` (it is not called)."""
    model_entry = RESNEST_MODELS_REGISTRY.get(model_name)
    return model_entry
Ejemplo n.º 14
0
from fvcore.common.registry import Registry

OPS_REGISTRY = Registry("OPS_REGISTRY")  # noqa F401 isort:skip
OPS_REGISTRY.__doc__ = """
Registry for ops.

:func:`build_ops` looks up ``cfg.name`` here and calls the entry as
``obj(cfg=cfg, **kwargs)``.
"""


def build_ops(cfg, **kwargs):
    """
    Build the op named by ``cfg.name``.

    The object registered under ``cfg.name`` in ``OPS_REGISTRY`` is called as
    ``obj(cfg=cfg, **kwargs)``; extra keyword arguments are forwarded unchanged.
    """
    name = cfg.name
    return OPS_REGISTRY.get(name)(cfg=cfg, **kwargs)


LAYER_REGISTRY = Registry("LAYER_REGISTRY")  # noqa F401 isort:skip
LAYER_REGISTRY.__doc__ = """
Registry for layers.

:func:`build_layer` looks up ``cfg.name`` here and calls the entry as
``obj(cfg=cfg, **kwargs)``.
"""


def build_layer(cfg, **kwargs):
    """
    Build the layer named by ``cfg.name``.

    The object registered under ``cfg.name`` in ``LAYER_REGISTRY`` is called
    as ``obj(cfg=cfg, **kwargs)``; extra keyword arguments are forwarded unchanged.
    """
    name = cfg.name
    return LAYER_REGISTRY.get(name)(cfg=cfg, **kwargs)
Ejemplo n.º 15
0
from fvcore.common.registry import Registry

D2LAYER_REGISTRY = Registry("D2LAYER_REGISTRY")  # noqa F401 isort:skip
D2LAYER_REGISTRY.__doc__ = """
Registry for d2 layers.

:func:`build_d2layer` looks up ``cfg.name`` here and calls the entry as
``obj(cfg=cfg, **kwargs)``.
"""


def build_d2layer(cfg, **kwargs):
    """
    Build the d2 layer named by ``cfg.name``.

    The object registered under ``cfg.name`` in ``D2LAYER_REGISTRY`` is called
    as ``obj(cfg=cfg, **kwargs)``; extra keyword arguments are forwarded unchanged.
    """
    name = cfg.name
    return D2LAYER_REGISTRY.get(name)(cfg=cfg, **kwargs)
Ejemplo n.º 16
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from fvcore.common.registry import Registry

DATASET_REGISTRY = Registry(name="DATASET")
# The text below is exposed as ``DATASET_REGISTRY.__doc__``.
DATASET_REGISTRY.__doc__ = """
Registry for dataset.

The registered object will be called with `obj(cfg, split)`.
The call should return a `torch.utils.data.Dataset` object.
"""


def build_dataset(dataset_name, cfg, split):
    """
    Build a dataset, defined by `dataset_name`.
    Args:
        dataset_name (str): the name of the dataset to be constructed. It is
            normalized with `str.capitalize()` before the registry lookup.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
    Returns:
        Dataset: a constructed dataset specified by dataset_name.
    """
    # `str.capitalize()` upper-cases the first character AND lower-cases the
    # rest, so config names are case-insensitive ("kinetics", "KINETICS" and
    # "Kinetics" all resolve to "Kinetics"). Registered dataset class names
    # must therefore have exactly this form: a class named e.g. "SSv2" could
    # never be found, while "Ssv2" can.
    name = dataset_name.capitalize()
    return DATASET_REGISTRY.get(name)(cfg, split)
Ejemplo n.º 17
0
# def build_optimizer(cfg, **kwargs):
#     """
#     Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
#     Note that it does not load any weights from ``cfg``.
#     """
#     return OPTIMIZER_REGISTRY.get(cfg.name)(cfg, **kwargs)
#
#
#
import logging
from fvcore.common.registry import Registry

from template_lib.utils import register_modules


REGISTRY = Registry("OPTIMIZER_REGISTRY")  # noqa F401 isort:skip
OPTIMIZER_REGISTRY = REGISTRY
REGISTRY.__doc__ = """
Registry for optimizers (also exported as ``OPTIMIZER_REGISTRY``).

``_build`` looks up ``cfg.name`` here, calls the entry as
``obj(cfg=cfg, **kwargs)``, and then clears the registry.
"""

def _build(cfg, **kwargs):
    """
    Build the object registered under ``cfg.name``.

    ``cfg.register_modules`` (default ``{}``) is first passed to
    ``register_modules``, presumably so the target classes register themselves
    into ``REGISTRY``. The entry is then called as ``obj(cfg=cfg, **kwargs)``.

    The registry is emptied after every build. This happens in a ``finally``
    clause, so a failing registration, lookup or constructor no longer leaves
    stale entries behind for the next build.
    """
    logging.getLogger('tl').info("Building %s ...", cfg.name)
    try:
        register_modules(register_modules=cfg.get('register_modules', {}))
        return REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    finally:
        REGISTRY._obj_map.clear()

def build_optimizer(cfg, **kwargs):
    """
    Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Ejemplo n.º 18
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

ENGINE_REGISTRY = Registry("ENGINE")
ENGINE_REGISTRY.__doc__ = """
    Registry for engine classes
"""
Ejemplo n.º 19
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Three independent named registries; entries are registered elsewhere.
DATASET_REGISTRY = Registry(name="DATASET")
DATASET_REGISTRY.__doc__ = """
    Registry for dataset classes
"""

DATASET_FACTORY_REGISTRY = Registry(name="DATASET_FACTORY")
DATASET_FACTORY_REGISTRY.__doc__ = """
    Registry for dataset factory classes
"""

EVAL_METHOD_REGISTRY = Registry(name="EVAL_METHOD")
EVAL_METHOD_REGISTRY.__doc__ = """
    Registry for eval method classes
"""
Ejemplo n.º 20
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

DATASET_MAPPER_REGISTRY = Registry(
    "DATASET_MAPPER_REGISTRY")  # noqa F401 isort:skip
DATASET_MAPPER_REGISTRY.__doc__ = """
Registry for dataset mappers.

:func:`build_dataset_mapper` looks up ``cfg.name`` here and calls the entry
as ``obj(cfg=cfg, **kwargs)``.
"""


def build_dataset_mapper(cfg, **kwargs):
    """
    Build the dataset mapper named by ``cfg.name``.

    If ``cfg.name`` is ``None`` or the string ``'none'`` (any case), no mapper
    is built and ``None`` is returned. Otherwise the object registered under
    ``cfg.name`` in ``DATASET_MAPPER_REGISTRY`` is called as
    ``obj(cfg=cfg, **kwargs)``.
    """
    name = cfg.name
    # Both an explicit None (e.g. YAML ``null``) and the string 'none' mean
    # "no mapper"; previously None crashed on ``.lower()``.
    if name is None or name.lower() == 'none':
        return None
    return DATASET_MAPPER_REGISTRY.get(name)(cfg=cfg, **kwargs)
Ejemplo n.º 21
0
#
#
# def build_d2distributions(cfg, **kwargs):
#     """
#     Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
#     Note that it does not load any weights from ``cfg``.
#     """
#     return D2DISTRIBUTIONS_REGISTRY.get(cfg.name)(cfg, **kwargs)
#

import logging
from fvcore.common.registry import Registry

from template_lib.utils import register_modules

REGISTRY = Registry("DISTRIBUTIONS_REGISTRY")  # noqa F401 isort:skip
DISTRIBUTIONS_REGISTRY = REGISTRY
REGISTRY.__doc__ = """
Registry for distributions (also exported as ``DISTRIBUTIONS_REGISTRY``).

``_build`` looks up ``cfg.name`` here, calls the entry as
``obj(cfg=cfg, **kwargs)``, and then clears the registry.
"""


def _build(cfg, **kwargs):
    """
    Build the object registered under ``cfg.name``.

    ``cfg.register_modules`` (default ``{}``) is first passed to
    ``register_modules``, presumably so the target classes register themselves
    into ``REGISTRY``. The entry is then called as ``obj(cfg=cfg, **kwargs)``.

    The registry is emptied after every build. This happens in a ``finally``
    clause, so a failing registration, lookup or constructor no longer leaves
    stale entries behind for the next build.
    """
    logging.getLogger('tl').info("Building %s ...", cfg.name)
    try:
        register_modules(register_modules=cfg.get('register_modules', {}))
        return REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    finally:
        REGISTRY._obj_map.clear()


def build_distributions(cfg, **kwargs):
Ejemplo n.º 22
0
import logging
from fvcore.common.registry import Registry

from template_lib.utils import register_modules


REGISTRY = Registry("MODEL_REGISTRY")  # noqa F401 isort:skip
MODEL_REGISTRY = REGISTRY
REGISTRY.__doc__ = """
Registry for models (also exported as ``MODEL_REGISTRY``).

``_build`` looks up ``cfg.name`` here, calls the entry as
``obj(cfg=cfg, **kwargs)``, and then clears the registry.
"""

def _build(cfg, **kwargs):
    """
    Build the object registered under ``cfg.name``.

    ``cfg.register_modules`` (default ``{}``) is first passed to
    ``register_modules``, presumably so the target classes register themselves
    into ``REGISTRY``. The entry is then called as ``obj(cfg=cfg, **kwargs)``.

    The registry is emptied after every build. This happens in a ``finally``
    clause, so a failing registration, lookup or constructor no longer leaves
    stale entries behind for the next build.
    """
    logging.getLogger('tl').info("Building %s ...", cfg.name)
    try:
        register_modules(register_modules=cfg.get('register_modules', {}))
        return REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    finally:
        REGISTRY._obj_map.clear()

def build_model(cfg, **kwargs):
    """
    Build the model named by ``cfg.name``.

    Delegates to ``_build``, which calls the entry registered under
    ``cfg.name`` as ``obj(cfg=cfg, **kwargs)``.
    """
    return _build(cfg, **kwargs)

"""
model construction function
"""
import torch
import torch.nn as nn
from fvcore.common.registry import Registry
from torch.nn import init
MODEL_REGISTRY = Registry(name="MODEL")  # model classes, registered elsewhere


def weights_init_kaiming(m):
    """
    kaiming init for a single module (presumably used via ``model.apply``).

    Dispatches on the module's class name:

    * contains ``Conv``: Kaiming-normal weight (``a=0``, ``mode='fan_in'``);
      bias left untouched.
    * contains ``Linear``: Kaiming-normal weight (``a=0``, ``mode='fan_out'``);
      bias set to 0.
    * contains ``BatchNorm2d``: weight drawn from N(1.0, 0.02); bias set to 0.

    A ``weight``/``bias`` attribute that is missing or ``None`` is skipped
    instead of raising. Examples: ``Linear(bias=False)``,
    ``BatchNorm2d(affine=False)``, or a container module whose class name
    merely contains ``Conv``.

    :param m: an ``nn.Module``
    :return: None; parameters are modified in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif 'Linear' in classname:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_out')
        if bias is not None:
            init.constant_(bias.data, 0.0)
    elif 'BatchNorm2d' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_xavier(m):
    """

    :param m:
Ejemplo n.º 24
0
"""
dataset construction function
"""
import  torch
import torch.nn as nn
from  fvcore.common.registry import  Registry
from torch.nn import init
DATASET_REGISTRY = Registry("DATASET")


def build_dataset(dataset_name, mode, cfg):
    """
    Build the dataset registered under ``dataset_name.capitalize()``.

    :param dataset_name: dataset name, e.g. ``avenue``; ``str.capitalize`` is
        applied first, so ``avenue`` looks up the entry ``Avenue``.
    :param mode: mode/split passed through to the dataset, e.g. ``train`` / ``test``
    :param cfg: config object passed through to the dataset
    :return: the result of calling the registered entry as ``obj(cfg, mode)``
    """
    name = dataset_name.capitalize()
    return DATASET_REGISTRY.get(name)(cfg, mode)



if __name__ == "__main__":
    # Running the module directly only prints a banner.
    print("dataset register")
Ejemplo n.º 25
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
    Registry for meta-architectures, i.e. the whole model.
    The registered object will be called with `obj(cfg)`
    and expected to return a `nn.Module` object.
"""

BASE_ARCH_REGISTRY = Registry("BASE_ARCH")
BASE_ARCH_REGISTRY.__doc__ = """
    Registry for base-architectures, i.e. the backbone model.
    The registered object will be called with `obj(cfg)`
    and expected to return a `nn.Module` object.
"""

AUX_ARCH_REGISTRY = Registry("AUX_ARCH")
AUX_ARCH_REGISTRY.__doc__ = """
    Registry for auxiliary-architectures (kept separate from the backbones
    in BASE_ARCH_REGISTRY).
    The registered object will be called with `obj(cfg)`
    and expected to return a `nn.Module` object.
"""