Esempio n. 1
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import mobile_cv.common.misc.registry as registry

from . import modeldef_utils as mdu
from .modeldef_utils import e1, e6

# Registry for classification model architecture definitions, created under
# the registry name "cls_arch_factory".
MODEL_ARCH = registry.Registry("cls_arch_factory")

# Default architecture definition table: maps an arch name ("default") to a
# dict whose "blocks" entry is a list of stages; each stage is a list of
# block-spec tuples laid out as described below.
# NOTE(review): this dict is not registered into MODEL_ARCH within this
# fragment -- confirm where (or whether) it is registered.
MODEL_ARCH_DEFAULT = {
    "default": {
        "blocks": [
            # [op, c, s, n, ...]
            # op is a block type name (e.g. "conv_k3", "ir_k3", "conv_k1");
            # presumably c = output channels, s = stride of the first block
            # and n = number of repeats -- TODO confirm against the builder.
            # e1 / e6 come from modeldef_utils; presumably per-block options
            # for expansion ratio 1 / 6 -- confirm in modeldef_utils.
            # stage 0
            [("conv_k3", 32, 2, 1)],
            # stage 1
            [("ir_k3", 16, 1, 1, e1)],
            # stage 2
            [("ir_k3", 24, 2, 2, e6)],
            # stage 3
            [("ir_k3", 32, 2, 3, e6)],
            # stage 4
            [("ir_k3", 64, 2, 4, e6), ("ir_k3", 96, 1, 3, e6)],
            # stage 5
            [("ir_k3", 160, 2, 3, e6), ("ir_k3", 320, 1, 1, e6)],
            # stage 6
            [("conv_k1", 1280, 1, 1)],
        ]
    },
}
Esempio n. 2
0
    NaiveSyncBatchNorm1d,
    NaiveSyncBatchNorm3d,
)
from torch.quantization.fuse_modules import (
    fuse_conv_bn,
    fuse_conv_bn_relu,
    fuse_known_modules,
)

# Registry to get the names of sub-modules to fuse for a supported module.
# A registered getter returns a list of lists of sub-module names to fuse
# together, with the signature:
# func(
#    module: torch.nn.Module,
#    supported_fusing_types: Dict[str, List[Type[torch.nn.Module]]]
# ) -> List[List[str]]
FUSE_LIST_GETTER = registry.Registry("fuse_list_getter")

# Module classes (not instances) recognized for conv -> bn -> relu fusion,
# grouped by their role in the fused pattern. The shape matches the
# `supported_fusing_types` argument described above.
# NOTE(review): presumably consumed together with fuse_conv_bn /
# fuse_conv_bn_relu imported above -- confirm at the call sites.
CONV_BN_RELU_SUPPORTED_FUSING_TYPES = {
    "conv": [nn.Conv1d, nn.Conv2d, nn.Conv3d],
    # Plain, synchronized and naive-synchronized batch-norm variants.
    "bn": [
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.SyncBatchNorm,
        NaiveSyncBatchNorm,
        NaiveSyncBatchNorm1d,
        NaiveSyncBatchNorm3d,
    ],
    "relu": [nn.ReLU],
}
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Task factory to create a task object based on the builder name and arch name

To register a task builder, add the builder to __init__.py and register it with
  from mobile_cv.model_zoo.tasks import task_factory
  @task_factory.TASK_FACTORY.register("classy")
  def classy_task(...):
      ...

To create a task, use
  from mobile_cv.model_zoo.tasks import task_factory
  task_factory.get(builder, ...)
"""

import mobile_cv.common.misc.registry as registry

TASK_FACTORY = registry.Registry("task_factory")


def get(builder, *args, **kwargs):
    """Create a task with the builder registered under ``builder``.

    Args:
        builder: name the task builder was registered with in TASK_FACTORY.
        *args: positional arguments forwarded unchanged to the builder.
        **kwargs: keyword arguments forwarded unchanged to the builder.

    Returns:
        Whatever the registered builder returns.
    """
    return TASK_FACTORY.get(builder)(*args, **kwargs)
Esempio n. 4
0
import mobile_cv.arch.utils.helper as hp
import mobile_cv.arch.utils.misc as utils_misc
import mobile_cv.common.misc.iter_utils as iu
import mobile_cv.common.misc.registry as registry
import torch
import torch.nn as nn
from mobile_cv.arch.layers import (
    GroupNorm,
    NaiveSyncBatchNorm,
    NaiveSyncBatchNorm1d,
    NaiveSyncBatchNorm3d,
)
from torch.nn.quantized.modules import FloatFunctional

# Component registries, one per pluggable building-block category; each is
# created under the name passed to registry.Registry.
BN_REGISTRY = registry.Registry("bn")
CONV_REGISTRY = registry.Registry("conv")
RELU_REGISTRY = registry.Registry("relu")
RESIDUAL_REGISTRY = registry.Registry("residual_connect")
UPSAMPLE_REGISTRY = registry.Registry("upsample")

# NOTE(review): `logging` is not among the imports visible in this fragment;
# confirm `import logging` exists earlier in the file.
logger = logging.getLogger(__name__)


class Identity(nn.Module):
    def __init__(self, in_channels, out_channels, stride, **kwargs):
        super().__init__()
        self.conv = None
        if in_channels != out_channels or stride != 1:
            self.conv = ConvBNRelu(
                in_channels,
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
FBNet model building blocks factory
"""

import mobile_cv.arch.utils.helper as hp
import mobile_cv.common.misc.registry as registry

from . import basic_blocks as bb
from . import irf_block

# Registry of FBNet building-block constructors, named "blocks_factory".
PRIMITIVES = registry.Registry("blocks_factory")

_PRIMITIVES = {
    "skip":
    lambda in_channels, out_channels, stride, **kwargs: bb.Identity(
        in_channels, out_channels, stride),
    "conv":
    lambda in_channels, out_channels, stride, **kwargs: bb.ConvBNRelu(
        in_channels, out_channels,
        **hp.merge(conv_args={"stride": stride}, kwargs=kwargs)),
    "conv_k1":
    lambda in_channels, out_channels, stride, **kwargs: bb.ConvBNRelu(
        in_channels, out_channels,
        **hp.merge(
            conv_args={
                "stride": stride,
                "kernel_size": 1,
                "padding": 0
            },
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Model zoo factory to create a model based on the model builder name and arch name

To register a model builder, add the builder to __init__.py and register it with
  from mobile_cv.model_zoo.models import model_zoo_factory
  @model_zoo_factory.MODEL_ZOO_FACTORY.register("fbnet_v2")
  def fbnet(...):
      ...

To create a model, use
  from mobile_cv.model_zoo.models import model_zoo_factory
  model_zoo_factory.get_model(builder, pretrained, ...)
"""

import mobile_cv.common.misc.registry as registry

MODEL_ZOO_FACTORY = registry.Registry("model_zoo_factory")


def get_model(builder, pretrained=False, progress=True, **kwargs):
    """Build a model with the builder registered under ``builder``.

    Args:
        builder: name the model builder was registered with in
            MODEL_ZOO_FACTORY.
        pretrained: passed to the builder as the ``pretrained`` keyword.
        progress: passed to the builder as the ``progress`` keyword.
        **kwargs: any further options, forwarded unchanged to the builder.

    Returns:
        Whatever the registered builder returns.
    """
    return MODEL_ZOO_FACTORY.get(builder)(
        pretrained=pretrained, progress=progress, **kwargs
    )
Esempio n. 7
0
FBNet model basic building blocks
"""

import logging
import numbers

import torch.nn as nn
from torch.nn.quantized.modules import FloatFunctional

import mobile_cv.arch.layers.misc as layers_misc
import mobile_cv.arch.utils.helper as hp
import mobile_cv.arch.utils.misc as utils_misc
import mobile_cv.common.misc.registry as registry
from mobile_cv.arch.layers import GroupNorm, NaiveSyncBatchNorm, interpolate

# Component registries for FBNet basic building blocks; each is created under
# the name passed to registry.Registry.
CONV_REGISTRY = registry.Registry("conv")
BN_REGISTRY = registry.Registry("bn")
RELU_REGISTRY = registry.Registry("relu")
UPSAMPLE_REGISTRY = registry.Registry("upsample")


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class Identity(nn.Module):
    def __init__(self, in_channels, out_channels, stride, **kwargs):
        super().__init__()
        self.conv = None
        if in_channels != out_channels or stride != 1:
            self.conv = ConvBNRelu(
                in_channels,
Esempio n. 8
0
import os
import typing

import mobile_cv.arch.utils.fuse_utils as fuse_utils
import mobile_cv.arch.utils.jit_utils as ju
import mobile_cv.arch.utils.quantize_utils as quantize_utils
import mobile_cv.common.misc.registry as registry
import mobile_cv.lut.lib.pt.flops_utils as flops_utils
import mobile_cv.model_zoo.tasks.task_factory as task_factory
import torch
from mobile_cv.model_zoo.models import model_utils
from torch.utils.mobile_optimizer import optimize_for_mobile

# Uses a fixed logger name rather than __name__.
# NOTE(review): `logging` is not among the imports visible in this fragment;
# confirm `import logging` exists earlier in the file.
logger = logging.getLogger("model_zoo_tools.export")

# Registry of exporters, created under the name "ExportFactory".
ExportFactory = registry.Registry("ExportFactory")
# Export format names; presumably the formats used when the caller does not
# request specific ones -- confirm against the argument parsing / main flow.
DEFAULT_EXPORT_FORMATS = ["torchscript", "torchscript_int8"]


def parse_args(args_list=None):
    parser = argparse.ArgumentParser(description="Model zoo model exporter")
    parser.add_argument(
        "--task",
        type=str,
        default=None,
        help="Task name, if @ is inside the name, use the str after it as the "
        "path to import",
    )
    parser.add_argument("--task_args",
                        type=json.loads,
                        default={},
Esempio n. 9
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
FBNet model inverse residual building block
"""

import numbers

import torch.nn as nn

import mobile_cv.arch.utils.helper as hp
import mobile_cv.common.misc.registry as registry

from . import basic_blocks as bb

# Registry of residual-connection builders, named "residual_connect".
RESIDUAL_REGISTRY = registry.Registry("residual_connect")


def build_residual_connect(name,
                           in_channels,
                           out_channels,
                           stride,
                           drop_connect_rate=None,
                           **res_args):
    if name is None:
        return None
    if name == "default":
        assert isinstance(stride, (numbers.Number, tuple, list))
        if isinstance(stride, (tuple, list)):
            stride_one = all(x == 1 for x in stride)
        else: