示例#1
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Model construction functions."""

import torch
from fvcore.common.registry import Registry

# Registry (named "MODEL") that holds the video model constructors; its
# contract is described in the __doc__ string assigned below.
MODEL_REGISTRY = Registry("MODEL")
MODEL_REGISTRY.__doc__ = """
Registry for video model.

The registered object will be called with `obj(cfg)`.
The call should return a `torch.nn.Module` object.
"""


def build_model(cfg, gpu_id=None):
    """
    Builds the video model.
    Args:
        cfg (configs): configs that contains the hyper-parameters to build the
        backbone. Details can be seen in slowfast/config/defaults.py.
        gpu_id (Optional[int]): specify the gpu index to build model.
    """
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        assert (
示例#2
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Registry (named "ENGINE") that holds engine classes.
ENGINE_REGISTRY = Registry("ENGINE")
ENGINE_REGISTRY.__doc__ = """
    Registry for engine classes
"""
示例#3
0
import logging
from fvcore.common.registry import Registry

from template_lib.utils import register_modules


# Single registry for model builders; ``MODEL_REGISTRY`` is an alias of the
# same object so callers can import either name.
REGISTRY = Registry("MODEL_REGISTRY")  # noqa F401 isort:skip
MODEL_REGISTRY = REGISTRY
REGISTRY.__doc__ = """
Registry for model architectures.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
``_build`` empties this registry after constructing an object.
"""

def _build(cfg, **kwargs):
    """Register the configured modules, then build ``cfg.name`` from REGISTRY.

    Args:
        cfg: config node exposing ``name`` and a dict-like ``get``; its
            optional ``register_modules`` entry (default ``{}``) is passed to
            ``register_modules`` before the lookup (presumably the modules
            whose import populates REGISTRY -- TODO confirm).
        **kwargs: extra keyword arguments forwarded to the registered
            constructor.

    Returns:
        Whatever the registered constructor returns.
    """
    logging.getLogger('tl').info(f"Building {cfg.name} ...")
    try:
        register_modules(register_modules=cfg.get('register_modules', {}))
        return REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    finally:
        # Always empty the registry, even when registration, lookup or
        # construction raises, so the next build starts from the same clean
        # state a successful build leaves behind instead of inheriting stale
        # entries from a failed one.
        REGISTRY._obj_map.clear()

def build_model(cfg, **kwargs):
    """
    Build the model registered under ``cfg.name``.

    Public wrapper around ``_build``: extra keyword arguments are forwarded to
    the registered constructor. This function itself does not load any
    weights from ``cfg``.
    """
    return _build(cfg, **kwargs)

示例#4
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from fvcore.common.registry import Registry

# Registry of dataset classes; ``build_dataset`` looks entries up by the
# capitalized dataset name.
DATASET_REGISTRY = Registry("DATASET")
DATASET_REGISTRY.__doc__ = """
Registry for dataset.

The registered object will be called with `obj(cfg, split)`.
The call should return a `torch.utils.data.Dataset` object.
"""


def build_dataset(dataset_name, cfg, split):
    """
    Build a dataset, defined by `dataset_name`.
    Args:
        dataset_name (str): the name of the dataset to be constructed. It is
            normalized with `str.capitalize()` before the registry lookup.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
    Returns:
        Dataset: a constructed dataset specified by dataset_name.
    """
    # Normalize the name with `str.capitalize()`: the name in configs may be
    # lowercase, while registered dataset class names start with an uppercase
    # letter. Note that `capitalize()` also lowercases every other character,
    # so e.g. "SSv2" is looked up as "Ssv2".
    name = dataset_name.capitalize()
    return DATASET_REGISTRY.get(name)(cfg, split)
示例#5
0
from fvcore.common.registry import Registry
import torch.nn as nn

# Registry of backbone constructors; ``build_backbone`` calls the selected
# entry with ``cfg`` as its only argument.
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images.
The registered object must be a callable that accepts one argument:
a :class:`detectron2.config.CfgNode` (``build_backbone`` calls it as ``obj(cfg)``).
It must return an instance of :class:`Backbone`.
"""


def build_backbone(cfg):
    """
    Construct the backbone selected by `cfg.MODEL.BACKBONE.NAME`.

    The name is resolved through BACKBONE_REGISTRY and the resulting
    callable is invoked with `cfg` as its only argument.

    Returns:
        an instance of :class:`Backbone`
    """
    factory = BACKBONE_REGISTRY.get(cfg.MODEL.BACKBONE.NAME)
    return factory(cfg)


def get_norm(cfg, out_channels, momentum=0.1):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
示例#6
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

DISCRIMINATOR_REGISTRY = Registry("DISCRIMINATOR_REGISTRY")  # noqa F401 isort:skip
DISCRIMINATOR_REGISTRY.__doc__ = """
Registry for discriminators.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
"""

GENERATOR_REGISTRY = Registry("GENERATOR_REGISTRY")  # noqa F401 isort:skip
GENERATOR_REGISTRY.__doc__ = """
Registry for generators.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
"""


def build_discriminator(cfg, **kwargs):
    """
    Build the discriminator registered under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        Whatever the registered constructor returns. No weights are loaded
        by this function.
    """
    name = cfg.name
    return DISCRIMINATOR_REGISTRY.get(name)(cfg=cfg, **kwargs)


def build_generator(cfg, **kwargs):
    """
    Build the generator registered under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        Whatever the registered constructor returns. No weights are loaded
        by this function.
    """
    name = cfg.name
    return GENERATOR_REGISTRY.get(name)(cfg=cfg, **kwargs)
示例#7
0
from fvcore.common.registry import Registry

D2LAYER_REGISTRY = Registry("D2LAYER_REGISTRY")  # noqa F401 isort:skip
D2LAYER_REGISTRY.__doc__ = """
Registry for layers built by ``build_d2layer``.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
"""


def build_d2layer(cfg, **kwargs):
    """
    Build the layer registered in D2LAYER_REGISTRY under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        Whatever the registered constructor returns.
    """
    name = cfg.name
    return D2LAYER_REGISTRY.get(name)(cfg=cfg, **kwargs)
示例#8
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

DATASET_MAPPER_REGISTRY = Registry("DATASET_MAPPER_REGISTRY")  # noqa F401 isort:skip
DATASET_MAPPER_REGISTRY.__doc__ = """
Registry for dataset mappers.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
"""


def build_dataset_mapper(cfg, **kwargs):
    """
    Build the dataset mapper registered under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed mapper, or ``None`` when ``cfg.name`` is ``None`` or
        the string ``'none'`` (case-insensitive), meaning "no mapper".
    """
    name = cfg.name
    # Both a missing name (None) and the string "none" disable the mapper;
    # testing ``is None`` first avoids calling ``.lower()`` on None.
    if name is None or name.lower() == 'none':
        return None
    return DATASET_MAPPER_REGISTRY.get(name)(cfg=cfg, **kwargs)
示例#9
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Registry (named "LOSS") that holds loss function classes.
LOSS_REGISTRY = Registry("LOSS")
LOSS_REGISTRY.__doc__ = """
    Registry for loss function classes
"""
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Registry (named "DATASET") that holds dataset classes.
DATASET_REGISTRY = Registry("DATASET")
DATASET_REGISTRY.__doc__ = """
    Registry for dataset classes
"""
# Registry (named "DATASET_FACTORY") that holds dataset factory classes.
DATASET_FACTORY_REGISTRY = Registry("DATASET_FACTORY")
DATASET_FACTORY_REGISTRY.__doc__ = """
    Registry for dataset factory classes
"""
# Registry (named "EVAL_METHOD") that holds evaluation method classes.
EVAL_METHOD_REGISTRY = Registry("EVAL_METHOD")
EVAL_METHOD_REGISTRY.__doc__ = """
    Registry for eval method classes
"""
示例#11
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Registry (named "HOOK") that holds hook classes.
HOOK_REGISTRY = Registry("HOOK")
HOOK_REGISTRY.__doc__ = """
    Registry for hook functional classes
"""
示例#12
0
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

GAN_METRIC_REGISTRY = Registry("GAN_METRIC_REGISTRY")  # noqa F401 isort:skip
GAN_METRIC_REGISTRY.__doc__ = """
Registry for GAN metrics.

Entries are looked up by name (``cfg.name`` in ``build_GAN_metric``, each of
``cfg.GAN_metric.names`` in ``build_GAN_metric_dict``) and called as
``obj(cfg=cfg, **kwargs)``.
"""


def build_GAN_metric_dict(cfg, **kwargs):
    """
    Build every GAN metric listed in ``cfg.GAN_metric.names``.

    Args:
        cfg: config node whose ``GAN_metric`` supports ``.get('names')``.
            Each listed name selects a registry entry, which is called as
            ``obj(cfg=cfg, **kwargs)`` with the whole ``cfg``.
        **kwargs: extra keyword arguments forwarded to every constructor.

    Returns:
        dict mapping each metric name to its constructed metric; empty when
        ``names`` is missing/None or empty.
    """
    names = cfg.GAN_metric.get('names')
    if names is None:
        return {}
    # Metrics are constructed in list order; a repeated name keeps the
    # instance built last (same as successive dict.update calls).
    return {name: GAN_METRIC_REGISTRY.get(name)(cfg=cfg, **kwargs)
            for name in names}


def build_GAN_metric(cfg, **kwargs):
    """
    Build the single GAN metric registered under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed metric.
    """
    metric = GAN_METRIC_REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    return metric
示例#13
0
from fvcore.common.registry import Registry

OPS_REGISTRY = Registry("OPS_REGISTRY")  # noqa F401 isort:skip
OPS_REGISTRY.__doc__ = """
Registry for ops built by ``build_ops``.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
"""


def build_ops(cfg, **kwargs):
    """
    Build the op registered in OPS_REGISTRY under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        Whatever the registered constructor returns.
    """
    name = cfg.name
    return OPS_REGISTRY.get(name)(cfg=cfg, **kwargs)


LAYER_REGISTRY = Registry("LAYER_REGISTRY")  # noqa F401 isort:skip
LAYER_REGISTRY.__doc__ = """
Registry for layers built by ``build_layer``.

Entries are looked up by ``cfg.name`` and called as ``obj(cfg=cfg, **kwargs)``.
"""


def build_layer(cfg, **kwargs):
    """
    Build the layer registered in LAYER_REGISTRY under ``cfg.name``.

    Args:
        cfg: config node; ``cfg.name`` selects the registry entry and the
            whole node is passed to the constructor as ``cfg=cfg``.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        Whatever the registered constructor returns.
    """
    name = cfg.name
    return LAYER_REGISTRY.get(name)(cfg=cfg, **kwargs)
示例#14
0
"""
@author:  Yuhao Cheng
@contact: yuhao.cheng[at]outlook.com
"""
from fvcore.common.registry import Registry

# Three registries for model architectures: the whole model (meta), the
# backbone (base), and auxiliary architectures.
META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
    Registry for meta-architectures, i.e. the whole model.
    The registered object will be called with `obj(cfg)`
    and expected to return a `nn.Module` object.
"""

BASE_ARCH_REGISTRY = Registry("BASE_ARCH")
BASE_ARCH_REGISTRY.__doc__ = """
    Registry for base-architectures, i.e. the backbone model.
    The registered object will be called with `obj(cfg)`
    and expected to return a `nn.Module` object.
"""

AUX_ARCH_REGISTRY = Registry("AUX_ARCH")
AUX_ARCH_REGISTRY.__doc__ = """
    Registry for auxiliary-architectures.
    The registered object will be called with `obj(cfg)`
    and expected to return a `nn.Module` object.
"""