Example #1
File: supernet.py Project: yinfupai/nni
    def __init__(self, lookup_table, num_points=106):
        """
        Parameters
        ----------
        lookup_table : class
            helper that manages the candidate ops, layer information and layer performance
        num_points : int
            the number of landmarks for prediction
        """
        super(PFLDInference, self).__init__()

        stage_names = list(lookup_table.layer_num)
        stage_lnum = [lookup_table.layer_num[stage] for stage in stage_names]
        self.stem = StemBlock(init_ch=INIT_CH, bottleneck=False)

        self.block4_1 = MBBlock(INIT_CH, 32, stride=2, mid_ch=32)

        stages_0 = [
            mutables.LayerChoice(
                choice_blocks(
                    lookup_table.layer_configs[layer_id],
                    lookup_table.lut_ops[stage_names[0]],
                )) for layer_id in range(stage_lnum[0])
        ]
        stages_1 = [
            mutables.LayerChoice(
                choice_blocks(
                    lookup_table.layer_configs[layer_id],
                    lookup_table.lut_ops[stage_names[1]],
                ))
            for layer_id in range(stage_lnum[0], stage_lnum[0] + stage_lnum[1])
        ]
        blocks = stages_0 + stages_1
        self.blocks = nn.Sequential(*blocks)

        self.avg_pool1 = nn.Conv2d(INIT_CH,
                                   INIT_CH,
                                   9,
                                   8,
                                   1,
                                   groups=INIT_CH,
                                   bias=False)
        self.avg_pool2 = nn.Conv2d(32, 32, 3, 2, 1, groups=32, bias=False)

        self.block6_1 = nn.Conv2d(96 + INIT_CH, 64, 1, 1, 0, bias=False)
        self.block6_2 = MBBlock(64, 64, res=True, se=True, mid_ch=128)
        self.block6_3 = SeparableConv(64, 128, 1)

        self.conv7 = nn.Conv2d(128, 128, 7, 1, 0, groups=128, bias=False)
        self.fc = nn.Conv2d(128, num_points * 2, 1, 1, 0, bias=True)

        # init params
        self.init_params()
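A hedged aside on the `mutables.LayerChoice` calls above: naming the candidates with an `OrderedDict` (the pattern used in Examples #10, #11 and #13 below) makes exported architectures easier to read than positional indices. A minimal standalone sketch, assuming the legacy `nni.nas.pytorch.mutables` API and a hypothetical key:

from collections import OrderedDict

import torch.nn as nn
from nni.nas.pytorch import mutables

# Named candidates export as "conv3x3"/"conv5x5"/"identity" instead of bare list indices.
layer = mutables.LayerChoice(OrderedDict([
    ("conv3x3", nn.Conv2d(32, 32, 3, padding=1)),
    ("conv5x5", nn.Conv2d(32, 32, 5, padding=2)),
    ("identity", nn.Identity()),
]), key="stage0_layer0")  # hypothetical key, not one used by this project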
Example #2
File: model.py Project: JSong-Jia/nni-1
 def __init__(self, node_id, num_prev_nodes, channels,
              num_downsample_connect):
     super().__init__()
     self.ops = nn.ModuleList()
     choice_keys = []
     for i in range(num_prev_nodes):
         stride = 2 if i < num_downsample_connect else 1
         choice_keys.append("{}_p{}".format(node_id, i))
         self.ops.append(
             mutables.LayerChoice([
                 ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                 ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                 nn.Identity() if stride == 1 else ops.FactorizedReduce(
                     channels, channels, affine=False),
                 ops.SepConv(channels, channels, 3, stride, 1,
                             affine=False),
                 ops.SepConv(channels, channels, 5, stride, 2,
                             affine=False),
                 ops.DilConv(
                     channels, channels, 3, stride, 2, 2, affine=False),
                 ops.DilConv(
                     channels, channels, 5, stride, 4, 2, affine=False)
             ], key=choice_keys[-1]))
     self.drop_path = ops.DropPath()
     self.input_switch = mutables.InputChoice(
         choose_from=choice_keys,
         n_chosen=2,
         key="{}_switch".format(node_id))
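Only the constructor is shown; a sketch of the forward pass such a node typically pairs with (modeled on the standard DARTS node, not code from this snippet) applies each per-edge LayerChoice to its predecessor, regularizes with drop-path, and lets the InputChoice reduce the chosen edges:

     def forward(self, prev_nodes):
         # one LayerChoice per incoming edge, applied to the matching predecessor
         assert len(self.ops) == len(prev_nodes)
         out = [op(node) for op, node in zip(self.ops, prev_nodes)]
         out = [self.drop_path(o) if o is not None else None for o in out]
         # InputChoice reduces (sums, by default) the n_chosen=2 selected edges
         return self.input_switch(out)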
Example #3
File: model.py Project: JSong-Jia/nni-1
    def __init__(self, key, prev_keys, hidden_units, choose_from_k,
                 cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask):
        super(Layer, self).__init__(key)

        def conv_shortcut(kernel_size):
            return ConvBN(kernel_size, hidden_units, hidden_units,
                          cnn_keep_prob, False, True)

        self.n_candidates = len(prev_keys)
        if self.n_candidates:
            self.prec = mutables.InputChoice(
                choose_from=prev_keys[-choose_from_k:], n_chosen=1)
        else:
            # first layer, skip input choice
            self.prec = None
        self.op = mutables.LayerChoice([
            conv_shortcut(1),
            conv_shortcut(3),
            conv_shortcut(5),
            conv_shortcut(7),
            AvgPool(3, False, True),
            MaxPool(3, False, True),
            RNN(hidden_units, lstm_keep_prob),
            Attention(hidden_units, 4, att_keep_prob, att_mask)
        ])
        if self.n_candidates:
            self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
        else:
            self.skipconnect = None
        self.bn = BatchNorm(hidden_units, False, True)
Example #4
    def _make_blocks(self, blocks, in_channels, channels):
        result = []
        for i in range(blocks):
            stride = 2 if i == 0 else 1
            inp = in_channels if i == 0 else channels
            oup = channels

            base_mid_channels = channels // 2
            mid_channels = int(base_mid_channels)  # prepare for scale
            choice_block = mutables.LayerChoice([
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
            ])
            result.append(choice_block)

            # find the corresponding flops
            flop_key = (inp, oup, mid_channels, self._feature_map_size, self._feature_map_size, stride)
            self._parsed_flops[choice_block.key] = [
                self._op_flops_dict["{}_stride_{}".format(k, stride)][flop_key] for k in self.block_keys
            ]
            if stride == 2:
                self._feature_map_size //= 2
        return result
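`self._parsed_flops` maps each LayerChoice key to a list of per-candidate FLOPs, index-aligned with the candidate blocks. A hedged sketch of how such a table can be consumed once an architecture has been sampled (the `sampled` dict format is an assumption, not this project's export format):

def arch_flops(parsed_flops, sampled):
    """Sum FLOPs over all choice blocks of one sampled architecture.

    parsed_flops: {choice_key: [flops_candidate_0, flops_candidate_1, ...]}
    sampled:      {choice_key: index_of_chosen_candidate}  # assumed format
    """
    return sum(parsed_flops[key][idx] for key, idx in sampled.items())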
Example #5
    def __init__(self, key, prev_labels, in_filters, out_filters):
        super().__init__(key)
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.mutable = mutables.LayerChoice([
#             ConvBranch(in_filters, out_filters, kernel_size=3, stride=1, separable=False),
#             ConvBranch(in_filters, out_filters, kernel_size=3, stride=1, separable=True),
            ConvBranch(in_filters, out_filters, kernel_size=5, stride=1, separable=False),
#             ConvBranch(in_filters, out_filters, kernel_size=5, stride=1, separable=True),
#             ConvBranch(in_filters, out_filters, kernel_size=7, stride=1, separable=False),
#             ConvBranch(in_filters, out_filters, kernel_size=7, stride=1, separable=True),
#             ConvBranch(in_filters, out_filters, kernel_size=9, stride=1, separable=False),
            ConvBranch(in_filters, out_filters, kernel_size=41, stride=1, separable=True),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=3, stride=1, separable=False),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=3, stride=1, separable=True),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=5, stride=1, separable=False),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=5, stride=1, separable=True),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=7, stride=1, separable=False),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=7, stride=1, separable=True),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=9, stride=1, separable=False),
#             ResidualConvBranch(in_filters, out_filters, kernel_size=9, stride=1, separable=True),
#             PoolBranch('avg', in_filters, out_filters, kernel_size=3, stride=1),
#             PoolBranch('max', in_filters, out_filters, kernel_size=3, stride=1),
        ])
        if len(prev_labels) > 0:
            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm1d(out_filters, affine=False)
Example #6
 def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
     super().__init__()
     self.ops = nn.ModuleList()
     choice_keys = []
     for i in range(num_prev_nodes):
         stride = 2 if i < num_downsample_connect else 1
         choice_keys.append("{}_p{}".format(node_id, i))
         self.ops.append(mutables.LayerChoice([ops.OPS[k](channels, stride, False) for k in ops.PRIMITIVES],
                                              key=choice_keys[-1]))
     self.drop_path = ops.DropPath()
     self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
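The candidate list here comes from a registry: `ops.OPS[k](channels, stride, False) for k in ops.PRIMITIVES`. A self-contained, simplified stand-in for that registry (plain torch modules; the real DARTS-style `ops` module wraps pooling with BatchNorm, uses separable convolutions, etc.):

import torch.nn as nn

# Hypothetical registry illustrating the factory pattern; not the real `ops` module.
PRIMITIVES = ["maxpool", "avgpool", "skipconnect"]
OPS = {
    "maxpool": lambda c, stride, affine: nn.MaxPool2d(3, stride, padding=1),
    "avgpool": lambda c, stride, affine: nn.AvgPool2d(3, stride, padding=1),
    "skipconnect": lambda c, stride, affine: (
        nn.Identity() if stride == 1 else nn.Conv2d(c, c, 1, stride, bias=False)),
}
candidates = [OPS[k](16, 1, False) for k in PRIMITIVES]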
Example #7
 def __init__(self, cell_name, prev_labels, channels):
     super().__init__()
     self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
                                              key=cell_name + "_input")
     self.op_choice = mutables.LayerChoice([
         SepConvBN(channels, channels, 3, 1),
         SepConvBN(channels, channels, 5, 2),
         Pool("avg", 3, 1, 1),
         Pool("max", 3, 1, 1),
         nn.Identity()
     ], key=cell_name + "_op")
Example #8
 def __init__(self, cell_name, prev_labels, channels, input_num, op_num,
              num_nodes):
     super().__init__()
     self.op_choice = mutables.LayerChoice([
         SepConvBN(channels, channels, 3, 1),
         SepConvBN(channels, channels, 5, 2),
         Pool("avg", 3, 1, 1),
         Pool("max", 3, 1, 1),
         nn.Identity()
     ], key=cell_name + "_op")
     self.input_num = input_num
     self.num_nodes = num_nodes
Example #9
 def __init__(self, key, prev_labels, out_shape):
     super().__init__(key)
     self.mutable = mutables.LayerChoice([
         block1(out_shape),
         block2(out_shape),
         block3(out_shape),
         block4(out_shape),
         block5(out_shape),
     ])
     if len(prev_labels) > 0:
         self.skipconnect = mutables.InputChoice(choose_from=prev_labels,
                                                 n_chosen=None)
     else:
         self.skipconnect = None
     self.batch_norm = nn.BatchNorm1d(out_shape)
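`block1` through `block5` are factories that each take `out_shape` and return a fresh candidate module. A hypothetical minimal factory matching that call pattern (the real blocks are project-specific; the only visible constraint is that the output width must match the `nn.BatchNorm1d(out_shape)` applied afterwards):

import torch.nn as nn

def block1(out_shape):
    # hypothetical candidate; output width matches BatchNorm1d(out_shape)
    return nn.Sequential(nn.Linear(out_shape, out_shape), nn.ReLU())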
Example #10
 def __init__(self, node_id, num_prev_nodes, out_shape):
     super().__init__()
     self.ops = nn.ModuleList()
     choice_keys = []
     for i in range(num_prev_nodes):
         choice_keys.append("{}_p{}".format(node_id, i))
         self.ops.append(
             mutables.LayerChoice(OrderedDict([
                 ("block1", block1(out_shape)),
                 ("block2", block2(out_shape)),
                 ("block3", block3(out_shape)),
                 ("block4", block4(out_shape)),
                 ("block5", block5(out_shape))
             ]), key=choice_keys[-1]))
     self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
Example #11
File: darts_cell.py Project: zctt00/nni
    def __init__(self, node_id, num_prev_nodes, channels,
                 num_downsample_connect):
        """
        Builtin DARTS node structure.

        Parameters
        ----------
        node_id : str
            node identifier, used to build unique mutable keys
        num_prev_nodes : int
            the number of previous nodes in this cell
        channels : int
            output channels
        num_downsample_connect : int
            the number of input connections to downsample with stride 2
            (nonzero only when this cell is a reduction cell)
        """
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool",
                     PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool",
                     PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1 else
                     FactorizedReduce(channels, channels, affine=False)),
                    ("sepconv3x3",
                     SepConv(channels, channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5",
                     SepConv(channels, channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3",
                     DilConv(channels, channels, 3, stride, 2, 2,
                             affine=False)),
                    ("dilconv5x5",
                     DilConv(channels, channels, 5, stride, 4, 2,
                             affine=False))
                ]), key=choice_keys[-1]))
        self.drop_path = DropPath()
        self.input_switch = mutables.InputChoice(
            choose_from=choice_keys,
            n_chosen=2,
            key="{}_switch".format(node_id))
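For reference, the call site quoted in Example #13's docstring constructs this node with `num_downsample_connect = 2` for reduction cells and `0` for normal cells, so exactly the two cell-input edges get stride-2 ops in a reduction cell. Assuming the enclosing class is named `Node`, as it is there:

# reduction cell at depth 2: both predecessors (the cell inputs) are downsampled
reduce_node = Node("reduce_n2", num_prev_nodes=2, channels=16, num_downsample_connect=2)
# normal cell: no downsampling on any edge
normal_node = Node("normal_n2", num_prev_nodes=2, channels=16, num_downsample_connect=0)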
Example #12
 def __init__(self, key, prev_labels, in_filters, out_filters):
     super().__init__(key)
     self.in_filters = in_filters
     self.out_filters = out_filters
     self.mutable = mutables.LayerChoice([
         ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
         ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
         ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
         ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
         PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
         PoolBranch('max', in_filters, out_filters, 3, 1, 1)
     ])
     if len(prev_labels) > 0:
         self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
     else:
         self.skipconnect = None
     self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
Example #13
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        '''
        Node("{}_n{}".format("reduce" if reduction else "normal", depth),
             depth, channels, 2 if reduction else 0)
        num_prev_nodes: the number of previous nodes
        '''

        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []  # records the key for each node+edge combination

        for i in range(num_prev_nodes):  # enumerate the previous nodes
            stride = 2 if i < num_downsample_connect else 1
            # set the stride uniformly:
            # stride=2 if this is a reduction cell,
            # stride=1 if this is a normal cell
            choice_keys.append("{}_p{}".format(node_id, i))

            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(
                        channels, channels, affine=False)),
                    ("sepconv3x3", ops.SepConv(channels,
                                               channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(channels,
                                               channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(channels,
                                               channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(channels,
                                               channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))

        self.drop_path = ops.DropPath()  # drop path with probability 0.2

        self.input_switch = mutables.InputChoice(  # controls how inputs are connected; choice_keys is maintained precisely for this
            choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
Example #14
def _parse_torch_module_from_submodule_spec(
        deepcv_module: 'deepcv.meta.base_module.DeepcvModule',
        submodule_spec: Union[Dict, Type[torch.nn.Module], str,
                              Callable[..., torch.nn.Module]],
        submodule_pos: Union[int, str],
        subm_creators: SUBMODULE_CREATORS_DICT_T,
        default_submodule_prefix: str = '_submodule_',
        allow_mutable_layer_choices: bool = True
) -> Tuple[str, torch.nn.Module]:
    """ Defines a single submodule of `DeepcvModule` from its respective NN architecture spec.  
    TODO: Refactor `_parse_torch_module_from_submodule_spec` and the underlying functions to be independent of `DeepcvModule` (and, for example, put those in a separate `yaml_nn_spec_parsing.py` file). (Context which still depends on the `DeepcvModule` instance: `deepcv_module._uses_nni_nas_mutables`, `deepcv_module._uses_forward_callback_submodules`, `deepcv_module._features_shapes`, `deepcv_module._hp`, `deepcv_module.HP_DEFAULTS`)
    """
    from deepcv.meta.submodule_creators import ForwardCallbackSubmodule
    subm_name = default_submodule_prefix + str(submodule_pos)
    subm_name, params, subm_type = _subm_name_and_params_from_spec(
        submodule_spec,
        default_subm_name=subm_name,
        existing_subm_names=deepcv_module._submodules.keys())

    # Add global (hyper)parameters from `hp` to `params` (allows parameters like `act_fn`, `dropout_prob`, `batch_norm`, ... to be defined either globally in `hp` or locally in `params` from submodule specs)
    # NOTE: If a parameter is specified both in `deepcv_module._hp` globals and in `params` local submodule specs, the `params` entry from the submodule specs always overrides the one from `hp`
    params_with_globals = {
        n: copy.deepcopy(v)
        for n, v in deepcv_module._hp.items() if n not in params
    }
    params_with_globals.update(params)

    if subm_type == yaml_tokens.NESTED_DEEPCV_MODULE:
        # Allow nested DeepCV sub-module (see deepcv/conf/base/parameters.yml for examples)
        module = type(deepcv_module)(
            input_shape=deepcv_module._features_shapes[-1],
            hp=params_with_globals,
            additional_subm_creators=subm_creators,
            extend_basic_subm_creators_dict=False,
            additional_init_logic=deepcv_module._additional_init_logic)
    elif subm_type == yaml_tokens.NAS_LAYER_CHOICE:
        deepcv_module._uses_nni_nas_mutables = True
        if not allow_mutable_layer_choices:
            raise ValueError(
                f'Error: nested LayerChoices are forbidden; an NNI NAS mutable LayerChoice cannot be specified as a candidate of another LayerChoice ("{subm_type}").'
            )
        # List of alternative submodules (nni_mutables.LayerChoice). Also makes sure candidate submodule names can't be referenced: LayerChoice candidates may have names (OrderedDict instead of List), but references are only allowed on the `yaml_tokens.NAS_LAYER_CHOICE` global name.
        # For more details on `LayerChoice`, see https://nni.readthedocs.io/en/latest/NAS/NasReference.html#nni.nas.pytorch.mutables.LayerChoice
        # NOTE: `isinstance` cannot take a subscripted generic like `Dict[str, Any]` (it raises TypeError), so check against plain `dict`
        if not isinstance(params, dict) \
                or yaml_tokens.NAS_LAYER_CHOICE_CANDIDATES not in params \
                or any(p not in {yaml_tokens.NAS_LAYER_CHOICE_CANDIDATES,
                                 yaml_tokens.NAS_LAYER_REDUCTION_FN,
                                 yaml_tokens.NAS_MUTABLE_RETURN_MASK}
                       for p in params.keys()):
            raise ValueError(
                f'Error: Parameters of a "{yaml_tokens.NAS_LAYER_CHOICE}" submodule specification must be a Dict which at least contains a `_candidates` parameter. '
                f'(And may eventually specify `{yaml_tokens.NAS_LAYER_REDUCTION_FN}`, `{yaml_tokens.NAS_MUTABLE_RETURN_MASK}` and/or `{yaml_tokens.SUBMODULE_NAME}` parameter(s)). NNI Mutable LayerChoice submodule params received: "{params}"'
            )
        prefix = f'{default_submodule_prefix}{submodule_pos}_candidate_'
        # `params` is a dict, so read the optional entries with .get() rather than getattr()
        reduction = params.get(yaml_tokens.NAS_LAYER_REDUCTION_FN,
                               DEFAULT_LAYER_CHOICE_REDUCTION)
        return_mask = params.get(yaml_tokens.NAS_MUTABLE_RETURN_MASK, False)

        # Parse candidates/alternative submodules from params (recursive call to `_parse_torch_module_from_submodule_spec`)
        candidate_refs, candidates = list(), OrderedDict()
        for j, candidate_spec in enumerate(
                params[yaml_tokens.NAS_LAYER_CHOICE_CANDIDATES]):
            candidate_name, candidate = _parse_torch_module_from_submodule_spec(
                deepcv_module,
                submodule_spec=candidate_spec,
                submodule_pos=j,
                subm_creators=subm_creators,
                default_submodule_prefix=prefix,
                allow_mutable_layer_choices=False)
            if getattr(candidate, 'referenced_submodules', None) is not None:
                candidate_refs.extend(candidate.referenced_submodules)
            else:
                # Ignore any tensor references if the candidate doesn't need them (no `referenced_submodules` attribute, so the submodule is assumed not to take a `referenced_submodules_out` argument)
                # NOTE: We assume the `referenced_submodules` attribute is reserved to modules which take a `referenced_submodules_out` arg (probably ForwardCallbackSubmodule)
                # Bind the original `forward` as a default argument: binding at definition time avoids infinite recursion once `candidate.forward` is patched, and avoids late binding of the `candidate` loop variable
                def _forward_monkey_patch(
                        *args,
                        referenced_submodules_out: Dict[
                            str, TENSOR_OR_SEQ_OF_TENSORS_T] = None,
                        _original_forward=candidate.forward,
                        **kwargs):
                    return _original_forward(*args, **kwargs)

                candidate.forward = _forward_monkey_patch
            candidates[candidate_name] = candidate

        # Instantiate the NNI NAS mutable LayerChoice from parsed candidates and parameters
        module = nni_mutables.LayerChoice(op_candidates=candidates,
                                          reduction=reduction,
                                          return_mask=return_mask,
                                          key=subm_name)
        # Candidate tensor references are aggregated and stored in the parent LayerChoice so that `deepcv_module._submodule_references` only stores references of top-level submodules (not nested candidates)
        module.referenced_submodules = candidate_refs
    else:
        # Parse a regular NN submodule from the specs (either based on a submodule creator, or directly on a `torch.nn.Module` type or string identifier)
        if isinstance(subm_type, str):
            # Try to find sub-module creator or a torch.nn.Module's `__init__` function which matches `subm_type` identifier
            fn_or_type = subm_creators.get(subm_type)
            if not fn_or_type:
                # If we can't find suitable function in module_creators, we try to evaluate function name (allows external functions to be used to define model's modules)
                try:
                    fn_or_type = deepcv.utils.get_by_identifier(subm_type)
                except Exception as e:
                    raise RuntimeError(
                        f'Error: Could not locate module/function named "{subm_type}" given module creators: "{subm_creators.keys()}"'
                    ) from e
        else:
            # The specified submodule is assumed to be a `torch.nn.Module` or `Callable[..., torch.nn.Module]` type which will be instantiated with its respective parameters as possible arguments according to its `__init__` signature (`params` and global NN spec parameters)
            fn_or_type = subm_type

        # Create layer/block submodule from its module_creator or its `torch.nn.Module.__init__()` method (`fn_or_type`)
        submodule_signature_params = inspect.signature(fn_or_type).parameters
        # Make it possible for a submodule to take specific arguments carrying the previous submodules' output shapes and the input tensor shape(s) from the previous submodule
        params_with_globals['prev_shapes'] = deepcv_module._features_shapes
        params_with_globals['input_shape'] = deepcv_module._features_shapes[-1]
        params_with_globals['input_shapes'] = deepcv_module._features_shapes[-1]
        provided_params = {
            n: v
            for n, v in params_with_globals.items()
            if n in submodule_signature_params
        }
        # Add `submodule_params` to `provided_params` if it is taken by the submodule creator (or the `torch.nn.Module` constructor)
        if 'submodule_params' in submodule_signature_params:
            # The `submodule_params` dict won't provide a param which is already provided through `provided_params` (each param is passed either inside the `submodule_params` dict or directly as an argument named after it)
            provided_params['submodule_params'] = {
                n: v
                for n, v in params.items() if n not in provided_params
            }
        # Create the submodule from its submodule creator or `torch.nn.Module` constructor
        module = fn_or_type(**provided_params)

        # Process the `torch.nn.Module` returned by the submodule creator so that `deepcv.meta.submodule_creators.ForwardCallbackSubmodule` instances are handled in a specific way for output tensor references (e.g. dense/residual) and NNI NAS mutable InputChoice support (these modules are defined by forward-pass callbacks which may be fed with referenced submodule(s) outputs in addition to the previous submodule's output)
        if isinstance(module, ForwardCallbackSubmodule):
            # Submodules which are instances of `deepcv.meta.submodule_creators.ForwardCallbackSubmodule` are handled separately, allowing output tensor (residual/dense) references and NNI NAS mutable InputChoice support (a `DeepcvModule`-specific `torch.nn.Module` defined from a callback called on forward passes)
            _setup_forward_callback_submodule(deepcv_module,
                                              subm_name,
                                              submodule_params=params,
                                              forward_callback_module=module)
        elif not isinstance(module, torch.nn.Module):
            raise RuntimeError(
                f'Error: Invalid sub-module creator function or type: '
                f'Must either be a `torch.nn.Module` (Type or string identifier of a Type) or a submodule creator which returns a `torch.nn.Module`.'
            )
    return subm_name, module
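The monkey-patch above binds the original `forward` as a default argument; without that, the patched function would call itself (since `candidate.forward` points at the patch once assigned) and, inside a loop, would late-bind to the last candidate. A standalone illustration of the safe pattern:

import torch.nn as nn

candidates = [nn.Linear(4, 4), nn.Linear(4, 4)]
for cand in candidates:
    def patched(*args, _orig=cand.forward, referenced_submodules_out=None, **kwargs):
        # _orig was captured at definition time: no recursion, no late binding
        return _orig(*args, **kwargs)
    cand.forward = patched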
Example #15
    def __call__(self, in_chs, model_block_args):
        """ Build the blocks
        Args:
            in_chs: Number of input channels passed to the first block
            model_block_args: A list of lists; the outer list defines stages, each
                inner list contains the block configuration dict(s) for that stage
        Returns:
            A flat list of blocks, each position wrapped in a mutables.LayerChoice
        """
        if self.verbose:
            self.logger.info('Building model trunk with %d stages...' %
                             len(model_block_args))
        self.in_chs = in_chs
        total_block_count = sum([len(x) for x in model_block_args])
        total_block_idx = 0
        current_stride = 2
        current_dilation = 1
        feature_idx = 0
        stages = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stage_idx, stage_block_args in enumerate(model_block_args):
            last_stack = stage_idx == (len(model_block_args) - 1)
            if self.verbose:
                self.logger.info('Stack: {}'.format(stage_idx))
            assert isinstance(stage_block_args, list)

            # blocks = []
            # each stack (stage) contains a list of block arguments
            for block_idx, block_args in enumerate(stage_block_args):
                last_block = block_idx == (len(stage_block_args) - 1)
                if self.verbose:
                    self.logger.info(' Block: {}'.format(block_idx))

                # Sort out stride, dilation, and feature extraction details
                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:
                    # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                    else:
                        current_stride = next_output_stride
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                if stage_idx == 0 or stage_idx == 6:
                    self.choice_num = 1
                else:
                    self.choice_num = len(self.choices)

                    if self.dil_conv:
                        self.choice_num += 2

                choice_blocks = []
                block_args_copy = deepcopy(block_args)
                if self.choice_num == 1:
                    # create the block
                    block = self._make_block(block_args, 0, total_block_idx,
                                             total_block_count)
                    choice_blocks.append(block)
                else:
                    for choice_idx, choice in enumerate(self.choices):
                        # create the block
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, choice[0],
                                                       choice[1])
                        block = self._make_block(block_args, choice_idx,
                                                 total_block_idx,
                                                 total_block_count)
                        choice_blocks.append(block)
                    if self.dil_conv:
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 3, 0)
                        block = self._make_block(block_args,
                                                 self.choice_num - 2,
                                                 total_block_idx,
                                                 total_block_count,
                                                 resunit=self.resunit,
                                                 dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 5, 0)
                        block = self._make_block(block_args,
                                                 self.choice_num - 1,
                                                 total_block_idx,
                                                 total_block_count,
                                                 resunit=self.resunit,
                                                 dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                    if self.resunit:
                        block = get_Bottleneck(block.conv_pw.in_channels,
                                               block.conv_pwl.out_channels,
                                               block.conv_dw.stride[0])
                        choice_blocks.append(block)

                choice_block = mutables.LayerChoice(choice_blocks)
                stages.append(choice_block)
                # create the block
                # block = self._make_block(block_args, total_block_idx, total_block_count)
                total_block_idx += 1  # incr global block idx (across all stacks)

            # stages.append(blocks)
        return stages
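The builder returns `stages` as a flat list of `mutables.LayerChoice` blocks; the per-stage `nn.Sequential` wrapping is commented out above. A hedged sketch, with stand-in candidates, of how a caller would typically chain the result:

import torch.nn as nn
from nni.nas.pytorch import mutables

# stand-in for the list this builder returns (real candidates come from _make_block)
stages = [mutables.LayerChoice([nn.Conv2d(8, 8, 3, padding=1), nn.Identity()])
          for _ in range(4)]
trunk = nn.Sequential(*stages)  # one searchable block per position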