Example #1
    def __init__(self,
                 search_space,
                 device,
                 genotypes,
                 num_classes=10,
                 dropout_rate=0.0,
                 dropblock_rate=0.0,
                 schedule_cfg=None):
        super(DenseGenotypeModel, self).__init__(schedule_cfg)

        self.search_space = search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes = list(
            genotype_from_str(genotypes, self.search_space)._asdict().values())
        self.num_classes = num_classes

        # training
        self.dropout_rate = dropout_rate
        self.dropblock_rate = dropblock_rate

        self._num_blocks = self.search_space.num_dense_blocks
        # build model
        self.stem = nn.Conv2d(3, self.genotypes[0], kernel_size=3, padding=1)

        self.dense_blocks = []
        self.trans_blocks = []
        last_channel = self.genotypes[0]
        for i_block in range(self._num_blocks):
            growths = self.genotypes[1 + i_block * 2]
            self.dense_blocks.append(
                self._new_dense_block(last_channel, growths))
            last_channel = int(last_channel + np.sum(growths))
            if i_block != self._num_blocks - 1:
                out_c = self.genotypes[2 + i_block * 2]
                self.trans_blocks.append(
                    self._new_transition_block(last_channel, out_c))
                last_channel = out_c
        self.dense_blocks = nn.ModuleList(self.dense_blocks)
        self.trans_blocks = nn.ModuleList(self.trans_blocks)

        self.final_bn = nn.BatchNorm2d(last_channel)
        self.final_relu = nn.ReLU()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()

        self.classifier = nn.Linear(last_channel, self.num_classes)

        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self.set_hook()
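A minimal sketch (plain Python, no aw_nas dependency) of the channel bookkeeping in the loop above, assuming the genotype layout [stem_channels, growths_0, trans_0, growths_1, trans_1, ...]; the values below are hypothetical:

import numpy as np

def trace_channels(genotypes, num_blocks):
    last_channel = genotypes[0]  # stem output channels
    for i_block in range(num_blocks):
        growths = genotypes[1 + i_block * 2]  # growth rates of this dense block
        last_channel = int(last_channel + np.sum(growths))
        if i_block != num_blocks - 1:
            # a transition block resets the channel count
            last_channel = genotypes[2 + i_block * 2]
    return last_channel

# two blocks with growth rates (12, 12) each and a 48-channel transition
print(trace_channels([24, (12, 12), 48, (12, 12)], num_blocks=2))  # -> 72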
Example #2
    def __init__(self, op_cls, search_space, layer_index, num_channels,
                 num_out_channels, prev_num_channels, stride, prev_strides,
                 use_preprocess, preprocess_op_type, use_shortcut,
                 shortcut_op_type, **op_kwargs):
        super(SharedCell, self).__init__()
        self.search_space = search_space
        self.stride = stride
        self.is_reduce = stride != 1
        self.num_channels = num_channels
        self.num_out_channels = num_out_channels
        self.layer_index = layer_index
        self.use_preprocess = use_preprocess
        self.preprocess_op_type = preprocess_op_type
        self.use_shortcut = use_shortcut
        self.shortcut_op_type = shortcut_op_type
        self.op_kwargs = op_kwargs

        self._steps = self.search_space.get_layer_num_steps(layer_index)
        self._num_init = self.search_space.num_init_nodes
        if not self.search_space.cellwise_primitives:
            # the same set of primitives shared across all cell groups
            self._primitives = self.search_space.shared_primitives
        else:
            # a different set of primitives for each cell group
            self._primitives = \
                self.search_space.cell_shared_primitives[self.search_space.cell_layout[layer_index]]

        # initialize self.concat_op, self._out_multiplier (only used for discrete super net)
        self.concat_op = ops.get_concat_op(self.search_space.concat_op)
        if not self.concat_op.is_elementwise:
            expect(
                not self.search_space.loose_end,
                "For shared-weights weights managers, a non-elementwise concat op "
                "does not support the loose-end search space")
            self._out_multiplier = self._steps if not self.search_space.concat_nodes \
                                   else len(self.search_space.concat_nodes)
        else:
            # elementwise concat op. e.g. sum, mean
            self._out_multiplier = 1

        self.preprocess_ops = nn.ModuleList()
        prev_strides = list(np.cumprod(list(reversed(prev_strides))))
        prev_strides.insert(0, 1)
        prev_strides = reversed(prev_strides[:len(prev_num_channels)])
        for prev_c, prev_s in zip(prev_num_channels, prev_strides):
            if not self.use_preprocess:
                # stride/channel not handled!
                self.preprocess_ops.append(ops.Identity())
                continue
            if self.preprocess_op_type is not None:
                # a specific preprocess op type is given; use it
                preprocess = ops.get_op(self.preprocess_op_type)(
                    C=prev_c,
                    C_out=num_channels,
                    stride=int(prev_s),
                    affine=False)
            else:
                if prev_s > 1:
                    # spatial size must be reduced; this is not the connection
                    # from the input image
                    preprocess = ops.FactorizedReduce(C_in=prev_c,
                                                      C_out=num_channels,
                                                      stride=prev_s,
                                                      affine=False)
                else:  # prev_c == _steps * num_channels or inputs
                    preprocess = ops.ReLUConvBN(C_in=prev_c,
                                                C_out=num_channels,
                                                kernel_size=1,
                                                stride=1,
                                                padding=0,
                                                affine=False)
            self.preprocess_ops.append(preprocess)
        assert len(self.preprocess_ops) == self._num_init

        if self.use_shortcut:
            self.shortcut_reduction_op = ops.get_op(self.shortcut_op_type)(
                C=prev_num_channels[-1],
                C_out=self.num_out_channel(),
                stride=self.stride,
                affine=True)

        self.edges = defaultdict(dict)
        self.edge_mod = torch.nn.Module()  # a stub module wrapping all the edges
        for i_step in range(self._steps):
            to_ = i_step + self._num_init
            for from_ in range(to_):
                self.edges[from_][to_] = op_cls(
                    self.num_channels,
                    self.num_out_channels,
                    stride=self.stride if from_ < self._num_init else 1,
                    primitives=self._primitives,
                    **op_kwargs)
                self.edge_mod.add_module("f_{}_t_{}".format(from_, to_),
                                         self.edges[from_][to_])

        self._edge_name_pattern = re.compile("f_([0-9]+)_t_([0-9]+)")
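As a quick sanity check on the nested loop above: step i adds node (num_init + i) with one incoming edge per earlier node, so the total edge count follows directly. A small sketch:

def num_cell_edges(num_init, steps):
    # step i adds node (num_init + i) with one incoming edge per earlier node
    return sum(num_init + i for i in range(steps))

print(num_cell_edges(num_init=2, steps=4))  # -> 14, as in a DARTS-style cell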
Example #3
    def __init__(self,
                 search_space,
                 device,
                 rollout_type,
                 cell_cls,
                 op_cls,
                 gpus=tuple(),
                 num_classes=10,
                 init_channels=16,
                 stem_multiplier=3,
                 max_grad_norm=5.0,
                 dropout_rate=0.1,
                 use_stem="conv_bn_3x3",
                 stem_stride=1,
                 stem_affine=True,
                 preprocess_op_type=None,
                 cell_use_preprocess=True,
                 cell_group_kwargs=None,
                 cell_use_shortcut=False,
                 cell_shortcut_op_type="skip_connect"):
        super(SharedNet, self).__init__(search_space, device, rollout_type)
        nn.Module.__init__(self)

        # optional data parallelism in SharedNet
        self.gpus = gpus

        self.num_classes = num_classes
        # initial channel number of the first cell layer;
        # doubled after every reduction cell
        self.init_channels = init_channels
        # stem conv channels = stem_multiplier * init_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem
        # possible cell group kwargs
        self.cell_group_kwargs = cell_group_kwargs
        # possible inter-cell shortcut
        self.cell_use_shortcut = cell_use_shortcut
        self.cell_shortcut_op_type = cell_shortcut_op_type

        # training
        self.max_grad_norm = max_grad_norm
        self.dropout_rate = dropout_rate

        # search space configs
        self._num_init = self.search_space.num_init_nodes
        self._cell_layout = self.search_space.cell_layout
        self._reduce_cgs = self.search_space.reduce_cell_groups
        self._num_layers = self.search_space.num_layers

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1
            for i_layer in range(self._num_layers)
        ]

        for i_layer, stride in enumerate(strides):
            if stride > 1:
                num_channels *= stride
            if cell_group_kwargs is not None:
                # support passing different kwargs when instantiating the
                # cell class for different cell groups
                kwargs = {
                    k: v
                    for k, v in cell_group_kwargs[
                        self._cell_layout[i_layer]].items()
                }
            else:
                kwargs = {}
            # A patch: input/output channels can be specified by hand in the
            # configuration, instead of relying on the default assumption
            # "whenever stride is 2, double the channels and map with preprocess operations"
            _num_channels = num_channels if "C_in" not in kwargs \
                            else kwargs.pop("C_in")
            _num_out_channels = num_channels if "C_out" not in kwargs \
                                else kwargs.pop("C_out")
            cell = cell_cls(op_cls,
                            self.search_space,
                            layer_index=i_layer,
                            num_channels=_num_channels,
                            num_out_channels=_num_out_channels,
                            prev_num_channels=tuple(prev_num_channels),
                            stride=stride,
                            prev_strides=init_strides + strides[:i_layer],
                            use_preprocess=cell_use_preprocess,
                            preprocess_op_type=preprocess_op_type,
                            use_shortcut=cell_use_shortcut,
                            shortcut_op_type=cell_shortcut_op_type,
                            **kwargs)
            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]
            self.cells.append(cell)

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(prev_num_channel, self.num_classes)

        self.to(self.device)
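A hedged sketch of the stride list and channel doubling computed above; the reduction layers here are hypothetical, in the real model they come from the search space's cell layout:

def channel_schedule(num_layers, init_channels, reduce_layers):
    num_channels, schedule = init_channels, []
    for i_layer in range(num_layers):
        stride = 2 if i_layer in reduce_layers else 1
        if stride > 1:
            num_channels *= stride  # double the width at every reduction
        schedule.append((stride, num_channels))
    return schedule

# e.g. 8 layers with reductions at layers 2 and 5 (hypothetical layout)
print(channel_schedule(8, 16, reduce_layers={2, 5}))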
Example #4
    def __init__(
        self,
        search_space,
        device,
        rollout_type="dense_rob",
        gpus=tuple(),
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        max_grad_norm=5.0,
        drop_rate=0.2,
        drop_out_rate=0.1,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        candidate_eval_no_grad=False,  # need grad in eval to craft adv examples
        calib_bn_batch=0,
        calib_bn_num=0,
    ):
        super(RobSharedNet, self).__init__(search_space, device, rollout_type)

        nn.Module.__init__(self)

        cell_cls = RobSharedCell
        op_cls = RobSharedOp
        # optional data parallelism in SharedNet
        self.gpus = gpus

        self.search_space = search_space
        self.num_classes = num_classes
        self.device = device
        self.drop_rate = drop_rate
        self.drop_out_rate = drop_out_rate
        self.init_channels = init_channels

        # stem conv channels = stem_multiplier * init_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem

        # training
        self.max_grad_norm = max_grad_norm

        # search space configs
        self._ops_choices = self.search_space.primitives
        self._num_layers = self.search_space.num_layers

        self._num_init = self.search_space.num_init_nodes
        self._cell_layout = self.search_space.cell_layout

        self.calib_bn_batch = calib_bn_batch
        self.calib_bn_num = calib_bn_num
        if self.calib_bn_num > 0 and self.calib_bn_batch > 0:
            self.logger.warning("`calib_bn_num` and `calib_bn_batch` are set simultaneously; "
                                "only `calib_bn_num` will be used")

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(
                        c_in, c_stem, stride=stem_stride, affine=stem_affine
                    )
                )
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(
                3, c_stem, stride=stem_stride, affine=stem_affine
            )
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1 for i_layer in range(self._num_layers)
        ]

        for i_layer, stride in enumerate(strides):
            if stride > 1:
                num_channels *= stride
            num_out_channels = num_channels
            cell = cell_cls(
                op_cls,
                self.search_space,
                num_input_channels=prev_num_channels,
                num_out_channels=num_out_channels,
                stride=stride,
                prev_strides=init_strides + strides[:i_layer],
            )

            self.cells.append(cell)
            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.drop_rate and self.drop_rate > 0:
            self.dropout = nn.Dropout(p=self.drop_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(prev_num_channels[-1], self.num_classes)
        self.to(self.device)

        self.candidate_eval_no_grad = candidate_eval_no_grad
        self.assembled = 0
        self.candidate_map = weakref.WeakValueDictionary()

        self.set_hook()
        self._flops_calculated = False
        self.total_flops = 0
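A minimal sketch of the `prev_num_channels` sliding window used in the cell loop above: after each cell it holds the output widths of the last `num_init` cells, which become the next cell's inputs (the widths below are hypothetical):

prev_num_channels = [48, 48]   # num_init = 2 copies of the stem width
for cell_out in [16, 16, 32]:  # hypothetical per-cell output widths
    prev_num_channels.append(cell_out)
    prev_num_channels = prev_num_channels[1:]
print(prev_num_channels)       # -> [16, 32]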
Example #5
    def __init__(
        self,
        search_space,
        device,
        genotypes,
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        schedule_cfg=None,
    ):
        super(DenseRobFinalModel, self).__init__(schedule_cfg)

        self.search_space = search_space
        self.device = device
        assert isinstance(genotypes, str)
        genotypes = genotype_from_str(genotypes, self.search_space)
        self.arch_list = self.search_space.rollout_from_genotype(
            genotypes).arch

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate

        # search space configs
        self._num_init = self.search_space.num_init_nodes
        self._num_layers = self.search_space.num_layers

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1
            for i_layer in range(self._num_layers)
        ]

        for i_layer, stride in enumerate(strides):
            if stride > 1:
                num_channels *= stride
            num_out_channels = num_channels
            kwargs = {}
            cg_idx = self.search_space.cell_layout[i_layer]

            cell = DenseRobCell(
                self.search_space,
                self.arch_list[cg_idx],
                # num_channels=num_channels,
                num_input_channels=prev_num_channels,
                num_out_channels=num_out_channels,
                # prev_num_channels=tuple(prev_num_channels),
                prev_strides=init_strides + strides[:i_layer],
                stride=stride,
                **kwargs)

            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]
            self.cells.append(cell)

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(prev_num_channels[-1], self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self.set_hook()
Example #6
import re
from collections import defaultdict

import six
import torch
from torch import nn

from aw_nas import utils, ops
from aw_nas.common import genotype_from_str
from aw_nas.ops import register_primitive
from aw_nas.final.base import FinalModel

skip_connect_2 = lambda C, C_out, stride, affine: \
                 ops.FactorizedReduce(C, C_out, stride=stride, affine=affine) if stride == 2 \
                 else (ops.Identity() if C == C_out else nn.Conv2d(C, C_out, 1, 1, 0))

register_primitive("skip_connect_2", skip_connect_2)
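# Illustration (a sketch, not exercised by this file): the three branches of
# `skip_connect_2` choose an op based on the stride and the channel match.
# skip_connect_2(C=16, C_out=32, stride=2, affine=False) -> ops.FactorizedReduce
# skip_connect_2(C=16, C_out=16, stride=1, affine=False) -> ops.Identity
# skip_connect_2(C=16, C_out=32, stride=1, affine=False) -> nn.Conv2d(16, 32, 1, 1, 0)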


class DenseRobFinalModel(FinalModel):
    NAME = "dense_rob_final_model"

    SCHEDULABLE_ATTRS = ["dropout_path_rate"]

    def __init__(self,
                 search_space,
                 device,
                 genotypes,
                 num_classes=10,
                 init_channels=36,
                 stem_multiplier=1,
                 dropout_rate=0.0,
                 dropout_path_rate=0.0,
                 use_stem="conv_bn_3x3",
                 stem_stride=1,
                 stem_affine=True,
                 schedule_cfg=None):
        # ... (body identical to the DenseRobFinalModel __init__ shown in Example #5 above)
Example #7
    def __init__(
        self,
        search_space,  # layer2
        device,
        genotypes,  # layer2
        micro_model_type="micro-dense-model",
        micro_model_cfg={},
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        auxiliary_head=False,
        auxiliary_cfg=None,
        schedule_cfg=None,
    ):
        super(MacroStagewiseFinalModel, self).__init__(schedule_cfg)

        self.macro_ss = search_space.macro_search_space
        self.micro_ss = search_space.micro_search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes_str = genotypes
        self.macro_g, self.micro_g = genotype_from_str(genotypes, search_space)

        # micro model (cell) class
        micro_model_cls = FinalModel.get_class_(micro_model_type)  # cell type

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.stem_multiplier = stem_multiplier
        self.stem_stride = stem_stride
        self.stem_affine = stem_affine
        self.use_stem = use_stem

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate

        self.auxiliary_head = auxiliary_head

        self.overall_adj = self.macro_ss.parse_overall_adj(self.macro_g)
        self.layer_widths = [float(w) for w in self.macro_g.width.split(",")]

        self.micro_model_cfg = micro_model_cfg
        if "postprocess" in self.micro_model_cfg.keys():
            self.cell_use_postprocess = self.micro_model_cfg["postprocess"]
        else:
            self.cell_use_postprocess = False

        # sort channels out
        assert self.stem_multiplier == 1, "Cannot handle stem_multiplier != 1 now"
        self.input_channel_list = [self.init_channels]
        for i in range(1, self.macro_ss.num_layers):
            self.input_channel_list.append(
                self.input_channel_list[i - 1] * 2 if self._is_reduce(i - 1)
                else self.input_channel_list[i - 1])
        for i in range(self.macro_ss.num_layers):
            self.input_channel_list[i] = int(
                self.input_channel_list[i] * self.layer_widths[i]
                if not self._is_reduce(i)
                else self.input_channel_list[i] * self.layer_widths[i - 1])

        self.output_channel_list = self.input_channel_list[1:] + [
            self.input_channel_list[-1]
        ]

        # construct cells
        if not self.use_stem:
            raise NotImplementedError
            c_stem = 3
        elif isinstance(self.use_stem, (list, tuple)):
            raise NotImplementedError
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stem = nn.Sequential(self.stems)
        else:
            self.stem = ops.get_op(self.use_stem)(3,
                                                  self.input_channel_list[0],
                                                  stride=stem_stride,
                                                  affine=stem_affine)

        self.extra_stem = ops.get_op("nor_conv_1x1")(
            self.input_channel_list[0],
            self.input_channel_list[0] * self.micro_ss.num_steps,
            stride=1,
            affine=True,
        )

        # For sink-connect, don't init all cells, just init connected cells

        connected_cells = []
        for cell_idx in range(1, self.macro_ss.num_layers + 2):
            if len(self.overall_adj[cell_idx].nonzero()[0]) > 0:
                connected_cells.append(self.overall_adj[cell_idx].nonzero()[0])
        # -1 to make the 1st element 0
        self.connected_cells = np.concatenate(connected_cells)[1:] - 1
        """
        ininitialize cells, only connected cells are initialized
        also use `use_next_stage_width` to handle the disalignment of width due to width search
        """
        self.cells = nn.ModuleList()
        self.micro_arch_list = self.micro_ss.rollout_from_genotype(
            self.micro_g).arch
        for i_layer in range(self.macro_ss.num_layers):
            stride = 2 if self._is_reduce(i_layer) else 1
            connected_is_reduce = [
                self._is_reduce(i) for i in self.connected_cells
            ]
            # the layer indices that use the next stage's width: the last cell
            # before the reduction cell in each stage
            use_next_stage_width_layer_idx = self.connected_cells[
                np.argwhere(np.array(connected_is_reduce)).reshape(-1) - 1]
            # the reduction cells' layer indices among the connected cells
            reduction_layer_idx = self.connected_cells[
                np.argwhere(np.array(connected_is_reduce)).reshape(-1)]
            # the width to use for `use_next_stage_width`: the reduction cell
            # has expansion 2, so the next stage's beginning width is halved
            # (both postprocess settings currently use the same computation)
            next_stage_widths = (
                np.array(self.output_channel_list)[self.macro_ss.stages_begin[1:]]
                // 2)
            use_next_stage_width = (
                next_stage_widths[np.argwhere(
                    use_next_stage_width_layer_idx == i_layer).reshape(-1)] if
                np.argwhere(use_next_stage_width_layer_idx == i_layer).size > 0
                else None)
            input_channel_list_n = np.array(self.input_channel_list)
            # the input of a reduction cell should be half of the next stage's width
            input_channel_list_n[reduction_layer_idx] = next_stage_widths

            cg_idx = self.macro_ss.cell_layout[i_layer]
            if i_layer not in self.connected_cells:
                continue
            # construct the micro cell
            cell = micro_model_cls(
                self.micro_ss,
                self.micro_arch_list[cg_idx],
                num_input_channels=int(
                    input_channel_list_n[i_layer]
                ),  # TODO: input_channel_list is of type: np.int64
                num_out_channels=self.output_channel_list[i_layer],
                stride=stride,
                use_next_stage_width=use_next_stage_width,
                is_last_cell=(i_layer == self.connected_cells[-1]),
                is_first_cell=(i_layer == self.connected_cells[0]),
                skip_cell=False,
                **micro_model_cfg)
            # assume non-reduce cell does not change channel number

            self.cells.append(cell)
            # add auxiliary head

        # connected_cells has one more element (the leading 0) than self.cells
        if self.auxiliary_head:
            self.where_aux_head = self.connected_cells[(2 * len(self.cells)) //
                                                       3]
            extra_expansion_for_aux = (
                1 if self.cell_use_postprocess else self.micro_ss.num_steps
            )  # without postprocess, the aux head's input channel count scales with num_steps
            # aux head is connected to last cell's output
            if auxiliary_head == "imagenet":
                self.auxiliary_net = AuxiliaryHeadImageNet(
                    input_channel_list_n[self.where_aux_head] *
                    extra_expansion_for_aux, num_classes,
                    **(auxiliary_cfg or {}))
            else:
                self.auxiliary_net = AuxiliaryHead(
                    input_channel_list_n[self.where_aux_head] *
                    extra_expansion_for_aux, num_classes,
                    **(auxiliary_cfg or {}))

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()

        if not self.cell_use_postprocess:
            self.classifier = nn.Linear(
                self.output_channel_list[-1] * self.micro_ss.num_steps,
                self.num_classes)
        else:
            self.classifier = nn.Linear(self.output_channel_list[-1],
                                        self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self._set_hook()
Example #8
    def __init__(
        self,
        search_space,  # layer2
        device,
        genotypes,  # layer2
        micro_model_type="micro-dense-model",
        micro_model_cfg={},
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        auxiliary_head=False,
        auxiliary_cfg=None,
        schedule_cfg=None,
    ):
        super(MacroStagewiseFinalModel, self).__init__(schedule_cfg)

        self.macro_ss = search_space.macro_search_space
        self.micro_ss = search_space.micro_search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes_str = genotypes
        self.macro_g, self.micro_g = genotype_from_str(genotypes, search_space)

        # micro model (cell) class
        micro_model_cls = FinalModel.get_class_(micro_model_type)  # cell type

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.stem_multiplier = stem_multiplier
        self.stem_stride = stem_stride
        self.stem_affine = stem_affine
        self.use_stem = use_stem

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate

        self.auxiliary_head = auxiliary_head

        self.overall_adj = self.macro_ss.parse_overall_adj(self.macro_g)
        self.layer_widths = [float(w) for w in self.macro_g.width.split(",")]

        # sort channels out
        assert self.stem_multiplier == 1, "Cannot handle stem_multiplier != 1 now"
        self.input_channel_list = [self.init_channels]
        for i in range(1, self.macro_ss.num_layers):
            self.input_channel_list.append(
                self.input_channel_list[i - 1] * 2 if self._is_reduce(i - 1)
                else self.input_channel_list[i - 1])
        for i in range(self.macro_ss.num_layers):
            self.input_channel_list[i] = int(
                self.input_channel_list[i] * self.layer_widths[i]
                if not self._is_reduce(i)
                else self.input_channel_list[i] * self.layer_widths[i - 1])

        self.output_channel_list = self.input_channel_list[1:] + [
            self.input_channel_list[-1]
        ]

        # construct cells
        if not self.use_stem:
            raise NotImplementedError
            c_stem = 3
        elif isinstance(self.use_stem, (list, tuple)):
            raise NotImplementedError
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stem = nn.Sequential(self.stems)
        else:
            self.stem = ops.get_op(self.use_stem)(3,
                                                  self.input_channel_list[0],
                                                  stride=stem_stride,
                                                  affine=stem_affine)

        self.cells = nn.ModuleList()
        self.micro_arch_list = self.micro_ss.rollout_from_genotype(
            self.micro_g).arch
        for i_layer in range(self.macro_ss.num_layers):
            # print(i_layer, self._is_reduce(i_layer))
            stride = 2 if self._is_reduce(i_layer) else 1
            cg_idx = self.macro_ss.cell_layout[i_layer]
            # construct the micro cell
            # FIXME: currently MacroStagewiseFinalModel does not support postprocess=False
            micro_model_cfg = {**micro_model_cfg, "postprocess": True}  # copy to avoid mutating the shared default dict
            cell = micro_model_cls(
                self.micro_ss,
                self.micro_arch_list[cg_idx],
                num_input_channels=self.input_channel_list[i_layer],
                num_out_channels=self.output_channel_list[i_layer],
                stride=stride,
                **micro_model_cfg)
            # assume non-reduce cell does not change channel number
            self.cells.append(cell)
            # add auxiliary head
            if i_layer == (
                    2 * self.macro_ss.num_layers) // 3 and self.auxiliary_head:
                if auxiliary_head == "imagenet":
                    self.auxiliary_net = AuxiliaryHeadImageNet(
                        self.output_channel_list[i_layer], num_classes,
                        **(auxiliary_cfg or {}))
                else:
                    self.auxiliary_net = AuxiliaryHead(
                        self.output_channel_list[i_layer], num_classes,
                        **(auxiliary_cfg or {}))

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(self.output_channel_list[-1],
                                    self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self._set_hook()
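A standalone sketch of the input/output channel-list computation shared by Examples #7 and #8; `is_reduce` and the width multipliers are hypothetical stand-ins for what the searched macro genotype provides:

def make_channel_lists(init_channels, num_layers, is_reduce, widths):
    in_ch = [init_channels]
    for i in range(1, num_layers):
        # double the width after every reduction layer
        in_ch.append(in_ch[i - 1] * 2 if is_reduce(i - 1) else in_ch[i - 1])
    # scale by the searched width; reduction layers reuse the previous stage's width
    in_ch = [int(c * (widths[i - 1] if is_reduce(i) else widths[i]))
             for i, c in enumerate(in_ch)]
    out_ch = in_ch[1:] + [in_ch[-1]]
    return in_ch, out_ch

# 6 layers, reductions at layers 2 and 4, all width multipliers 1.0 (hypothetical)
print(make_channel_lists(36, 6, lambda i: i in (2, 4), [1.0] * 6))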
Example #9
    def __init__(self,
                 search_space,
                 device,
                 genotypes,
                 num_classes=10,
                 init_channels=36,
                 layer_channels=tuple(),
                 stem_multiplier=3,
                 dropout_rate=0.1,
                 dropout_path_rate=0.2,
                 auxiliary_head=False,
                 auxiliary_cfg=None,
                 use_stem="conv_bn_3x3",
                 stem_stride=1,
                 stem_affine=True,
                 no_fc=False,
                 cell_use_preprocess=True,
                 cell_pool_batchnorm=False,
                 cell_group_kwargs=None,
                 cell_independent_conn=False,
                 cell_preprocess_stride="skip_connect",
                 cell_preprocess_normal="relu_conv_bn_1x1",
                 schedule_cfg=None):
        super(CNNGenotypeModel, self).__init__(schedule_cfg)

        self.search_space = search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes = list(
            genotype_from_str(genotypes, self.search_space)._asdict().values())
        self.genotypes_grouped = list(
            zip([
                group_and_sort_by_to_node(conns)
                for conns in self.genotypes[:self.search_space.num_cell_groups]
            ], self.genotypes[self.search_space.num_cell_groups:]))
        # self.genotypes_grouped = [group_and_sort_by_to_node(g[1]) for g in self.genotypes\
        #                           if "concat" not in g[0]]

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.layer_channels = layer_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem
        self.cell_use_preprocess = cell_use_preprocess
        self.cell_group_kwargs = cell_group_kwargs
        self.cell_independent_conn = cell_independent_conn
        self.no_fc = no_fc

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate
        self.auxiliary_head = auxiliary_head

        # search space configs
        self._num_init = self.search_space.num_init_nodes
        self._cell_layout = self.search_space.cell_layout
        self._reduce_cgs = self.search_space.reduce_cell_groups
        self._num_layers = self.search_space.num_layers
        expect(len(self.genotypes_grouped) == self.search_space.num_cell_groups,
               ("Config genotype cell group number({}) "
                "does not match search_space cell group number({})")\
               .format(len(self.genotypes_grouped), self.search_space.num_cell_groups))

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1
            for i_layer in range(self._num_layers)
        ]
        if self.layer_channels:
            expect(len(self.layer_channels) == len(strides) + 1,
                   ("Config cell channels({}) does not match search_space num layers + 1 ({})"\
                    .format(len(self.layer_channels), self.search_space.num_layers + 1)),
                   ConfigException)
        for i_layer, stride in enumerate(strides):
            if self.layer_channels:
                # input and output channels of every layer is specified
                num_channels = self.layer_channels[i_layer]
                num_out_channels = self.layer_channels[i_layer + 1]
            else:
                if stride > 1:
                    num_channels *= stride
                num_out_channels = num_channels
            if cell_group_kwargs is not None:
                # support passing different kwargs when instantiating the cell
                # class for different cell groups. Input/output channels can be
                # specified by hand in the configuration, instead of relying on
                # the default assumption "whenever stride is 2, double the
                # channels and map with preprocess operations"
                kwargs = {
                    k: v
                    for k, v in cell_group_kwargs[
                        self._cell_layout[i_layer]].items()
                }
                if "C_in" in kwargs:
                    num_channels = kwargs.pop("C_in")
                if "C_out" in kwargs:
                    num_out_channels = kwargs.pop("C_out")
            else:
                kwargs = {}
            cg_idx = self.search_space.cell_layout[i_layer]

            cell = CNNGenotypeCell(self.search_space,
                                   self.genotypes_grouped[cg_idx],
                                   layer_index=i_layer,
                                   num_channels=num_channels,
                                   num_out_channels=num_out_channels,
                                   prev_num_channels=tuple(prev_num_channels),
                                   stride=stride,
                                   prev_strides=init_strides +
                                   strides[:i_layer],
                                   use_preprocess=cell_use_preprocess,
                                   pool_batchnorm=cell_pool_batchnorm,
                                   independent_conn=cell_independent_conn,
                                   preprocess_stride=cell_preprocess_stride,
                                   preprocess_normal=cell_preprocess_normal,
                                   **kwargs)
            # TODO: support specifying the concat nodes explicitly
            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]
            self.cells.append(cell)

            if i_layer == (2 * self._num_layers) // 3 and self.auxiliary_head:
                if auxiliary_head == "imagenet":
                    self.auxiliary_net = AuxiliaryHeadImageNet(
                        prev_num_channels[-1], num_classes,
                        **(auxiliary_cfg or {}))
                else:
                    self.auxiliary_net = AuxiliaryHead(prev_num_channels[-1],
                                                       num_classes,
                                                       **(auxiliary_cfg or {}))

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        if self.no_fc:
            self.classifier = ops.Identity()
        else:
            self.classifier = nn.Linear(prev_num_channels[-1],
                                        self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self.set_hook()
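When `layer_channels` is passed, it pins every layer's widths explicitly; a short sketch of the indexing convention checked by the expect(...) above (the values are hypothetical):

layer_channels = (36, 36, 72, 72, 144, 144, 144)  # num_layers + 1 entries for 6 layers
for i_layer in range(6):
    num_channels = layer_channels[i_layer]          # input width of layer i_layer
    num_out_channels = layer_channels[i_layer + 1]  # output width of layer i_layer
    print(i_layer, num_channels, num_out_channels)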
Example #10
    def __init__(self, search_space, genotype_grouped, layer_index,
                 num_channels, num_out_channels, prev_num_channels, stride,
                 prev_strides, use_preprocess, pool_batchnorm,
                 independent_conn, preprocess_stride, preprocess_normal,
                 **op_kwargs):
        super(CNNGenotypeCell, self).__init__()
        self.search_space = search_space
        self.conns_grouped, self.concat_nodes = genotype_grouped
        self.stride = stride
        self.is_reduce = stride != 1
        self.num_channels = num_channels
        self.num_out_channels = num_out_channels
        self.layer_index = layer_index
        self.use_preprocess = use_preprocess
        self.pool_batchnorm = pool_batchnorm
        self.independent_conn = independent_conn
        self.op_kwargs = op_kwargs

        self._steps = self.search_space.get_layer_num_steps(layer_index)
        self._num_init = self.search_space.num_init_nodes
        self._primitives = self.search_space.shared_primitives

        # initialize self.concat_op, self._out_multiplier (only used for discrete super net)
        self.concat_op = ops.get_concat_op(self.search_space.concat_op)
        if not self.concat_op.is_elementwise:
            expect(
                not self.search_space.loose_end,
                "For shared-weights weights managers, a non-elementwise concat op "
                "does not support the loose-end search space")
            self._out_multiplier = self._steps if not self.search_space.concat_nodes \
                                   else len(self.search_space.concat_nodes)
        else:
            # elementwise concat op. e.g. sum, mean
            self._out_multiplier = 1

        self.preprocess_ops = nn.ModuleList()
        prev_strides = list(np.cumprod(list(reversed(prev_strides))))
        prev_strides.insert(0, 1)
        prev_strides = reversed(prev_strides[:len(prev_num_channels)])
        for prev_c, prev_s in zip(prev_num_channels, prev_strides):
            # print("cin: {}, cout: {}, stride: {}".format(prev_c, num_channels, prev_s))
            if not self.use_preprocess:
                # stride/channel not handled!
                self.preprocess_ops.append(ops.Identity())
                continue
            if prev_s > 1:
                # spatial size must be reduced; this is not the connection from
                # the input image (the stride op here replaces ops.FactorizedReduce)
                preprocess = ops.get_op(preprocess_stride)(C=prev_c,
                                                           C_out=num_channels,
                                                           stride=prev_s,
                                                           affine=True)
            else:  # prev_c == _steps * num_channels or inputs
                preprocess = ops.get_op(preprocess_normal)(C=prev_c,
                                                           C_out=num_channels,
                                                           stride=1,
                                                           affine=True)
            self.preprocess_ops.append(preprocess)
        assert len(self.preprocess_ops) == self._num_init

        self.edges = _defaultdict_3()
        self.edge_mod = torch.nn.Module()  # a stub module wrapping all the edges
        for _, conns in self.conns_grouped:
            for op_type, from_, to_ in conns:
                stride = self.stride if from_ < self._num_init else 1
                op = ops.get_op(op_type)(num_channels, num_out_channels,
                                         stride, True, **op_kwargs)
                if self.pool_batchnorm and "pool" in op_type:
                    op = nn.Sequential(
                        op, nn.BatchNorm2d(num_out_channels, affine=False))
                index = len(self.edges[from_][to_][op_type])
                if index == 0 or self.independent_conn:
                    # this connection has not been established yet, or an
                    # independent connection is used for each identical
                    # (from, to, op_type) triple
                    self.edges[from_][to_][op_type][index] = op
                    self.edge_mod.add_module(
                        "f_{}_t_{}-{}-{}".format(from_, to_, op_type, index),
                        op)

        self._edge_name_pattern = re.compile(
            "f_([0-9]+)_t_([0-9]+)-([a-z0-9_-]+)-([0-9])")