#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """Model construction functions.""" import torch from fvcore.common.registry import Registry MODEL_REGISTRY = Registry("MODEL") MODEL_REGISTRY.__doc__ = """ Registry for video model. The registered object will be called with `obj(cfg)`. The call should return a `torch.nn.Module` object. """ def build_model(cfg, gpu_id=None): """ Builds the video model. Args: cfg (configs): configs that contains the hyper-parameters to build the backbone. Details can be seen in slowfast/config/defaults.py. gpu_id (Optional[int]): specify the gpu index to build model. """ if torch.cuda.is_available(): assert ( cfg.NUM_GPUS <= torch.cuda.device_count() ), "Cannot use more GPU devices than available" else: assert (
""" @author: Yuhao Cheng @contact: yuhao.cheng[at]outlook.com """ from fvcore.common.registry import Registry ENGINE_REGISTRY = Registry("ENGINE") ENGINE_REGISTRY.__doc__ = """ Registry for hook functional classes """
import logging

from fvcore.common.registry import Registry

from template_lib.utils import register_modules

REGISTRY = Registry("MODEL_REGISTRY")  # noqa F401 isort:skip
# Alias so callers can import the registry under either name.
MODEL_REGISTRY = REGISTRY
REGISTRY.__doc__ = """
"""


def _build(cfg, **kwargs):
    """Instantiate the object registered in ``REGISTRY`` under ``cfg.name``.

    First hands ``cfg.register_modules`` (``{}`` if absent) to
    ``register_modules``, which presumably registers the listed classes into
    ``REGISTRY`` (see ``template_lib.utils`` -- confirm). Then calls the
    registered object as ``obj(cfg=cfg, **kwargs)``.

    The registry is emptied afterwards, also when registration or
    construction raises. Leftover entries would otherwise trip fvcore's
    duplicate-name assertion the next time the same modules are registered.

    Args:
        cfg: config node with a ``name`` attribute and a ``get`` method.
        **kwargs: forwarded to the registered constructor.

    Returns:
        Whatever the registered object returns.
    """
    logging.getLogger('tl').info(f"Building {cfg.name} ...")
    try:
        register_modules(register_modules=cfg.get('register_modules', {}))
        return REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    finally:
        # Always reset, so a failed build cannot leave stale registrations.
        REGISTRY._obj_map.clear()


def build_model(cfg, **kwargs):
    """Build the model registered under ``cfg.name``.

    Thin public wrapper around ``_build``. This function itself loads no
    weights from ``cfg``.

    Args:
        cfg: config node; ``cfg.name`` selects the registered model and
            ``cfg.register_modules`` (optional) is passed to
            ``register_modules`` before the lookup.
        **kwargs: forwarded to the model constructor.
    """
    return _build(cfg, **kwargs)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from fvcore.common.registry import Registry

DATASET_REGISTRY = Registry("DATASET")
DATASET_REGISTRY.__doc__ = """
Registry for dataset.

The registered object will be called with `obj(cfg, split)`.
The call should return a `torch.utils.data.Dataset` object.
"""


def build_dataset(dataset_name, cfg, split):
    """
    Look up `dataset_name` in DATASET_REGISTRY and construct the dataset.

    Args:
        dataset_name (str): name of the dataset to construct. It is normalized
            with `str.capitalize()` before the registry lookup.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
    Returns:
        Dataset: the object returned by calling the registered class with
            `(cfg, split)`.
    """
    # Config files may spell dataset names in lowercase, while registered
    # class names start with an uppercase letter. str.capitalize() also
    # lowercases every character after the first ("AVA" -> "Ava"), so the
    # registered class names must use exactly that form.
    dataset_cls = DATASET_REGISTRY.get(dataset_name.capitalize())
    return dataset_cls(cfg, split)
from fvcore.common.registry import Registry import torch.nn as nn BACKBONE_REGISTRY = Registry("BACKBONE") BACKBONE_REGISTRY.__doc__ = """ Registry for backbones, which extract feature maps from images The registered object must be a callable that accepts two arguments: 1. A :class:`detectron2.config.CfgNode` 2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification. It must returns an instance of :class:`Backbone`. """ def build_backbone(cfg): """ Build a backbone from `cfg.MODEL.BACKBONE.NAME`. Returns: an instance of :class:`Backbone` """ backbone_name = cfg.MODEL.BACKBONE.NAME backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg) return backbone def get_norm(cfg, out_channels, momentum=0.1): """ Args: norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; or a callable that takes a channel number and returns the normalization layer as a nn.Module.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

DISCRIMINATOR_REGISTRY = Registry("DISCRIMINATOR_REGISTRY")  # noqa F401 isort:skip
DISCRIMINATOR_REGISTRY.__doc__ = """
"""

GENERATOR_REGISTRY = Registry("GENERATOR_REGISTRY")  # noqa F401 isort:skip
GENERATOR_REGISTRY.__doc__ = """
"""


def build_discriminator(cfg, **kwargs):
    """Build the discriminator registered under ``cfg.name``.

    Looks up ``cfg.name`` in ``DISCRIMINATOR_REGISTRY`` and calls the
    registered object as ``obj(cfg=cfg, **kwargs)``.

    Args:
        cfg: config node whose ``name`` selects the registered discriminator.
        **kwargs: forwarded to the registered constructor.

    Returns:
        Whatever the registered object returns.
    """
    name = cfg.name
    return DISCRIMINATOR_REGISTRY.get(name)(cfg=cfg, **kwargs)


def build_generator(cfg, **kwargs):
    """Build the generator registered under ``cfg.name``.

    Looks up ``cfg.name`` in ``GENERATOR_REGISTRY`` and calls the registered
    object as ``obj(cfg=cfg, **kwargs)``.

    Args:
        cfg: config node whose ``name`` selects the registered generator.
        **kwargs: forwarded to the registered constructor.

    Returns:
        Whatever the registered object returns.
    """
    name = cfg.name
    return GENERATOR_REGISTRY.get(name)(cfg=cfg, **kwargs)
from fvcore.common.registry import Registry

D2LAYER_REGISTRY = Registry("D2LAYER_REGISTRY")  # noqa F401 isort:skip
D2LAYER_REGISTRY.__doc__ = """
"""


def build_d2layer(cfg, **kwargs):
    """Build the layer registered under ``cfg.name`` in ``D2LAYER_REGISTRY``.

    The registered object is called as ``obj(cfg=cfg, **kwargs)``.

    Args:
        cfg: config node whose ``name`` selects the registered layer.
        **kwargs: forwarded to the registered constructor.

    Returns:
        Whatever the registered object returns.
    """
    name = cfg.name
    return D2LAYER_REGISTRY.get(name)(cfg=cfg, **kwargs)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

DATASET_MAPPER_REGISTRY = Registry(
    "DATASET_MAPPER_REGISTRY")  # noqa F401 isort:skip
DATASET_MAPPER_REGISTRY.__doc__ = """
"""


def build_dataset_mapper(cfg, **kwargs):
    """Build the dataset mapper registered under ``cfg.name``.

    Args:
        cfg: config node whose ``name`` selects the registered mapper.
        **kwargs: forwarded to the registered constructor.

    Returns:
        ``None`` when ``cfg.name`` is ``None`` or the string ``'none'`` (any
        letter case), meaning "no mapper". Otherwise, the result of calling
        the registered object as ``obj(cfg=cfg, **kwargs)``.
    """
    name = cfg.name
    # An unset/null config value may arrive as None; calling .lower() on it
    # would raise AttributeError, so treat it the same as 'none'.
    if name is None or name.lower() == 'none':
        return None
    return DATASET_MAPPER_REGISTRY.get(name)(cfg=cfg, **kwargs)
""" @author: Yuhao Cheng @contact: yuhao.cheng[at]outlook.com """ from fvcore.common.registry import Registry LOSS_REGISTRY = Registry("LOSS") LOSS_REGISTRY.__doc__ = """ Registry for loss function classes """
""" @author: Yuhao Cheng @contact: yuhao.cheng[at]outlook.com """ from fvcore.common.registry import Registry DATASET_REGISTRY = Registry("DATASET") DATASET_REGISTRY.__doc__ = """ Registry for dataset classes """ DATASET_FACTORY_REGISTRY = Registry("DATASET_FACTORY") DATASET_FACTORY_REGISTRY.__doc__ = """ Registry for dataset factory classes """ EVAL_METHOD_REGISTRY = Registry("EVAL_METHOD") EVAL_METHOD_REGISTRY.__doc__ = """ Registry for eval method classes """
""" @author: Yuhao Cheng @contact: yuhao.cheng[at]outlook.com """ from fvcore.common.registry import Registry HOOK_REGISTRY = Registry("HOOK") HOOK_REGISTRY.__doc__ = """ Registry for hook functional classes """
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from fvcore.common.registry import Registry

GAN_METRIC_REGISTRY = Registry("GAN_METRIC_REGISTRY")  # noqa F401 isort:skip
GAN_METRIC_REGISTRY.__doc__ = """
"""


def build_GAN_metric_dict(cfg, **kwargs):
    """Build every GAN metric listed in ``cfg.GAN_metric.names``.

    Args:
        cfg: config node with a dict-like ``GAN_metric`` entry (supports
            ``.get``). The whole ``cfg`` is passed to each metric.
        **kwargs: forwarded to every metric constructor.

    Returns:
        dict mapping each listed name to
        ``GAN_METRIC_REGISTRY.get(name)(cfg=cfg, **kwargs)``. Returns an empty
        dict when ``names`` is missing or ``None``. If a name is listed twice,
        the later instance wins.
    """
    names = cfg.GAN_metric.get('names')
    if names is None:
        return {}
    return {
        name: GAN_METRIC_REGISTRY.get(name)(cfg=cfg, **kwargs)
        for name in names
    }


def build_GAN_metric(cfg, **kwargs):
    """Build the single GAN metric registered under ``cfg.name``.

    Args:
        cfg: config node whose ``name`` selects the registered metric.
        **kwargs: forwarded to the registered constructor.

    Returns:
        The result of ``GAN_METRIC_REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)``.
    """
    metric = GAN_METRIC_REGISTRY.get(cfg.name)(cfg=cfg, **kwargs)
    return metric
from fvcore.common.registry import Registry

OPS_REGISTRY = Registry("OPS_REGISTRY")  # noqa F401 isort:skip
OPS_REGISTRY.__doc__ = """
"""


def build_ops(cfg, **kwargs):
    """Build the op registered under ``cfg.name`` in ``OPS_REGISTRY``.

    The registered object is called as ``obj(cfg=cfg, **kwargs)``.

    Args:
        cfg: config node whose ``name`` selects the registered op.
        **kwargs: forwarded to the registered constructor.

    Returns:
        Whatever the registered object returns.
    """
    name = cfg.name
    return OPS_REGISTRY.get(name)(cfg=cfg, **kwargs)


LAYER_REGISTRY = Registry("LAYER_REGISTRY")  # noqa F401 isort:skip
LAYER_REGISTRY.__doc__ = """
"""


def build_layer(cfg, **kwargs):
    """Build the layer registered under ``cfg.name`` in ``LAYER_REGISTRY``.

    The registered object is called as ``obj(cfg=cfg, **kwargs)``.

    Args:
        cfg: config node whose ``name`` selects the registered layer.
        **kwargs: forwarded to the registered constructor.

    Returns:
        Whatever the registered object returns.
    """
    name = cfg.name
    return LAYER_REGISTRY.get(name)(cfg=cfg, **kwargs)
""" @author: Yuhao Cheng @contact: yuhao.cheng[at]outlook.com """ from fvcore.common.registry import Registry META_ARCH_REGISTRY = Registry("META_ARCH") META_ARCH_REGISTRY.__doc__ = """ Registry for meta-architectures, i.e. the whole model. The registered object will be called with `obj(cfg)` and expected to return a `nn.Module` object. """ BASE_ARCH_REGISTRY = Registry("BASE_ARCH") BASE_ARCH_REGISTRY.__doc__ = """ Registry for base-architectures, i.e. the backbone model. The registered object will be called with `obj(cfg)` and expected to return a `nn.Module` object. """ AUX_ARCH_REGISTRY = Registry("AUX_ARCH") AUX_ARCH_REGISTRY.__doc__ = """ Registry for auxiliary-architectures, i.e. the backbone model. The registered object will be called with `obj(cfg)` and expected to return a `nn.Module` object. """