Example 1
    def __init__(
        self,
        search_space,
        cell_arch,
        num_input_channels,
        num_out_channels,
        stride,
        prev_strides,
    ):
        super(DenseRobCell, self).__init__()

        self.search_space = search_space
        self.arch = cell_arch
        self.stride = stride
        self.is_reduce = stride != 1
        self.num_input_channels = num_input_channels
        self.num_out_channels = num_out_channels
        self.num_init_nodes = self.search_space.num_init_nodes

        self.preprocess_ops = nn.ModuleList()
        prev_strides = prev_strides[-self.num_init_nodes:]
        prev_strides = list(np.cumprod(list(reversed(prev_strides))))
        prev_strides.insert(0, 1)
        prev_strides = reversed(prev_strides[:len(num_input_channels)])
        for prev_c, prev_s in zip(num_input_channels, prev_strides):
            preprocess = ops.get_op("skip_connect_2")(C=prev_c,
                                                      C_out=num_out_channels,
                                                      stride=prev_s,
                                                      affine=True)
            self.preprocess_ops.append(preprocess)

        self._num_nodes = self.search_space._num_nodes
        self._primitives = self.search_space.primitives
        self.num_init_nodes = self.search_space.num_init_nodes

        self.edges = defaultdict(dict)
        self.edge_mod = torch.nn.Module()  # a stub wrapping module of all the edges
        for from_ in range(self._num_nodes):
            for to_ in range(max(self.num_init_nodes, from_ + 1),
                             self._num_nodes):
                self.edges[from_][to_] = ops.get_op(
                    self._primitives[int(self.arch[to_][from_])]
                )(
                    # self.num_input_channels[from_] \
                    # if from_ < self.num_init_nodes else self.num_out_channels,
                    self.num_out_channels,
                    self.num_out_channels,
                    stride=self.stride if from_ < self.num_init_nodes else 1,
                    affine=False,
                )
                self.edge_mod.add_module("f_{}_t_{}".format(from_, to_),
                                         self.edges[from_][to_])
        self._edge_name_pattern = re.compile("f_([0-9]+)_t_([0-9]+)")
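The snippets in this corpus only show constructors. As a point of reference, the following standalone sketch (an assumption, not the aw_nas forward pass; nn.Conv2d stands in for ops.get_op(...)) shows how such an upper-triangular edge layout is typically consumed: each intermediate node sums the outputs of all incoming edges, and the intermediate node outputs are concatenated.

import torch
import torch.nn as nn
from collections import defaultdict

num_init_nodes, num_nodes, C = 2, 4, 8
edges = defaultdict(dict)
for from_ in range(num_nodes):
    for to_ in range(max(num_init_nodes, from_ + 1), num_nodes):
        edges[from_][to_] = nn.Conv2d(C, C, 1)  # stand-in for ops.get_op(...)(...)

inputs = [torch.randn(1, C, 8, 8) for _ in range(num_init_nodes)]
states = list(inputs)
for to_ in range(num_init_nodes, num_nodes):
    # each intermediate node sums the outputs of all of its incoming edges
    states.append(sum(edges[from_][to_](states[from_]) for from_ in range(to_)))
out = torch.cat(states[num_init_nodes:], dim=1)  # concatenate the intermediate nodes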
Example 2
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        affine,
        primitives,
        num_steps,
        num_init_nodes,
        output_op="concat",
        postprocess_op="conv_1x1",
        cell_shortcut=False,
        cell_shortcut_op="skip_connect",
    ):
        super().__init__()

        self.out_channels = out_channels
        self.stride = stride

        self.primitives = primitives
        self.num_init_nodes = num_init_nodes
        self.num_nodes = num_steps + num_init_nodes
        self.output_op = output_op

        # it's easier to calc edge indices with a longer ModuleList and some None
        self.edges = nn.ModuleList()
        for j in range(self.num_nodes):
            for i in range(self.num_nodes):
                if j > i:
                    if i < self.num_init_nodes:
                        self.edges.append(
                            Layer2MicroEdge(primitives, in_channels,
                                            out_channels, stride, affine))
                    else:
                        self.edges.append(
                            Layer2MicroEdge(primitives, out_channels,
                                            out_channels, 1, affine))
                else:
                    self.edges.append(None)

        if cell_shortcut and cell_shortcut_op != "none":
            self.shortcut = ops.get_op(cell_shortcut_op)(in_channels,
                                                         out_channels, stride,
                                                         affine)
        else:
            self.shortcut = None

        if self.output_op == "concat":
            self.postprocess = ops.get_op(postprocess_op)(out_channels *
                                                          num_steps,
                                                          out_channels,
                                                          stride=1,
                                                          affine=False)
Example 3
    def __init__(self, num_hid, primitives, share_w, batch_norm, shared_module,
                 **kwargs):
        #pylint: disable=invalid-name
        super(RNNSharedOp, self).__init__()
        self.num_hid = num_hid
        self.primitives = primitives
        self.p_ops = nn.ModuleList()
        self.share_w = share_w
        self.batch_norm = batch_norm
        if shared_module is None:
            if share_w:  # share weights between different activation functions
                self.W = nn.Linear(num_hid, 2 * num_hid, bias=False)
                self.W.weight.data.uniform_(-INIT_RANGE, INIT_RANGE)
            else:
                self.Ws = nn.ModuleList([
                    nn.Linear(num_hid, 2 * num_hid, bias=False)
                    for _ in range(len(self.primitives))
                ])
                for mod in self.Ws:
                    mod.weight.data.uniform_(-INIT_RANGE, INIT_RANGE)
        else:
            self.W = shared_module
            self.share_w = True

        if batch_norm:
            self.bn = nn.BatchNorm1d(2 * num_hid, affine=True)

        for primitive in self.primitives:
            op = ops.get_op(primitive)(**kwargs)
            self.p_ops.append(op)
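The 2 * num_hid output width of each linear layer suggests the DARTS-style RNN node update, where the projection is split into a gate and a candidate hidden state. A minimal sketch of that update (an assumption; the forward pass is not part of the snippet, and tanh stands in for the chosen activation primitive):

import torch
import torch.nn as nn

num_hid = 32
W = nn.Linear(num_hid, 2 * num_hid, bias=False)
h = torch.randn(4, num_hid)

c, h_cand = torch.chunk(W(h), 2, dim=-1)
c = torch.sigmoid(c)                        # gate
new_h = h + c * (torch.tanh(h_cand) - h)    # gated update of the hidden state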
Example 4
    def __init__(self,
                 C,
                 C_out,
                 stride,
                 primitives,
                 partial_channel_proportion=None):
        super(SharedOp, self).__init__()

        self.primitives = primitives
        self.stride = stride
        self.partial_channel_proportion = partial_channel_proportion

        if self.partial_channel_proportion is not None:
            expect(
                C % self.partial_channel_proportion == 0,
                "#channels must be divisible by partial_channel_proportion",
                ConfigException)
            expect(
                C_out % self.partial_channel_proportion == 0,
                "#channels must be divisible by partial_channel_proportion",
                ConfigException)
            C = C // self.partial_channel_proportion
            C_out = C_out // self.partial_channel_proportion

        self.p_ops = nn.ModuleList()
        for primitive in self.primitives:
            op = ops.get_op(primitive)(C, C_out, stride, False)
            if "pool" in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C_out, affine=False))

            self.p_ops.append(op)
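The forward pass of SharedOp is not included here. In DARTS-style supernets such a shared op is usually evaluated as a softmax-weighted sum of all candidate outputs; the following standalone sketch illustrates that pattern, with plain PyTorch modules standing in for ops.get_op(primitive) and the partial-channel branch above ignored.

import torch
import torch.nn as nn
import torch.nn.functional as F

candidates = nn.ModuleList([
    nn.Conv2d(16, 16, 3, padding=1),          # stand-ins for the registered primitives
    nn.MaxPool2d(3, stride=1, padding=1),
    nn.Identity(),
])
alpha = torch.randn(len(candidates), requires_grad=True)  # architecture parameters

x = torch.randn(2, 16, 8, 8)
weights = F.softmax(alpha, dim=-1)
out = sum(w * op(x) for w, op in zip(weights, candidates))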
Example 5
    def __init__(self, primitives, in_channels, out_channels, stride, affine):
        super().__init__()

        assert "none" not in primitives, "Edge should not have `none` primitive"

        self.ops = nn.ModuleList(
            ops.get_op(prim)(in_channels, out_channels, stride, affine)
            for prim in primitives)
Example 6
    def __init__(self, C, C_out, stride, primitives):
        super(SharedOp, self).__init__()
        self.primitives = primitives
        self.stride = stride
        self.p_ops = nn.ModuleList()
        for primitive in self.primitives:
            op = ops.get_op(primitive)(C, C_out, stride, False)
            if "pool" in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C_out, affine=False))
            self.p_ops.append(op)
Example 7
    def __init__(self, C, C_out, stride, primitives):
        super(RobSharedOp, self).__init__()
        self.stride = stride
        self.primitives = primitives
        self.p_ops = nn.ModuleList()

        # Load the candidate operations and save in self.p_ops
        for primitive in self.primitives:
            op = ops.get_op(primitive)(C, C_out, stride, False)
            self.p_ops.append(op)
Example 8
    def __init__(
        self,
        expansion,
        C,
        C_out,
        stride,
        kernel_size,
        affine,
        activation="relu",
        inv_bottleneck=None,
        depth_wise=None,
        point_linear=None,
    ):
        super(MobileNetV2Block, self).__init__()
        self.expansion = expansion
        self.C = C
        self.C_out = C_out
        self.C_inner = make_divisible(C * expansion, 8)
        self.stride = stride
        self.kernel_size = kernel_size
        self.act_fn = get_op(activation)()

        self.inv_bottleneck = None
        if expansion != 1:
            self.inv_bottleneck = inv_bottleneck or nn.Sequential(
                nn.Conv2d(C, self.C_inner, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.C_inner), self.act_fn)

        self.depth_wise = depth_wise or nn.Sequential(
            nn.Conv2d(self.C_inner,
                      self.C_inner,
                      self.kernel_size,
                      stride,
                      padding=self.kernel_size // 2,
                      bias=False), nn.BatchNorm2d(self.C_inner), self.act_fn)

        self.point_linear = point_linear or nn.Sequential(
            nn.Conv2d(self.C_inner, C_out, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C_out))

        self.shortcut = nn.Sequential()
        self.has_conv_shortcut = False
        if stride == 1 and C != C_out:
            self.has_conv_shortcut = True
            self.shortcut = nn.Sequential(
                nn.Conv2d(C,
                          C_out,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.BatchNorm2d(C_out),
            )
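make_divisible is not defined in this snippet; it is commonly implemented as in the MobileNet reference code, rounding a channel count to the nearest multiple of a divisor while never reducing it by more than 10%. The helper below is therefore an assumption about the function used here:

def make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # make sure rounding down does not reduce the value by more than 10%
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(make_divisible(16 * 6, 8))  # 96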
Example 9
    def __init__(
        self,
        op_cls,
        search_space,
        num_input_channels,
        num_out_channels,
        stride,
        prev_strides,
    ):
        super(RobSharedCell, self).__init__()

        self.search_space = search_space
        self.stride = stride
        self.num_input_channels = num_input_channels
        self.num_out_channels = num_out_channels
        self.prev_strides = prev_strides
        self._primitives = self.search_space.primitives
        self._num_nodes = self.search_space._num_nodes
        self.num_init_nodes = self.search_space.num_init_nodes
        self.preprocess_ops = nn.ModuleList()
        prev_strides = prev_strides[-self.num_init_nodes:]
        prev_strides = list(np.cumprod(list(reversed(prev_strides))))
        prev_strides.insert(0, 1)
        prev_strides = list(reversed(prev_strides[:len(num_input_channels)]))
        for prev_c, prev_s in zip(num_input_channels, prev_strides):
            preprocess = ops.get_op("skip_connect_2")(C=prev_c,
                                                      C_out=num_out_channels,
                                                      stride=prev_s,
                                                      affine=True)
            self.preprocess_ops.append(preprocess)

        self.edges = defaultdict(dict)
        self.edge_mod = torch.nn.Module()
        self.is_reduce = stride != 1

        # We save all the operations on edges in an upper triangular matrix
        for from_ in range(self._num_nodes):
            for to_ in range(max(self.num_init_nodes, from_ + 1),
                             self._num_nodes):
                self.edges[from_][to_] = op_cls(
                    self.num_out_channels,
                    self.num_out_channels,
                    stride=self.stride if from_ < self.num_init_nodes else 1,
                    primitives=self._primitives,
                )
                self.edge_mod.add_module("f_{}_t_{}".format(from_, to_),
                                         self.edges[from_][to_])
        self._edge_name_pattern = re.compile("f_([0-9]+)_t_([0-9]+)")
Example 10
    def __init__(self, search_space, device, genotypes, schedule_cfg=None):
        super(GeneralGenotypeModel, self).__init__(schedule_cfg)
        self.search_space = search_space
        self.device = device

        if isinstance(genotypes, str):
            self.genotypes = list(
                genotype_from_str(genotypes, self.search_space))
        else:
            self.genotypes = copy.deepcopy(genotypes)
        model = []
        for geno in copy.deepcopy(self.genotypes):
            op = geno.pop("prim_type")
            geno.pop("spatial_size")
            model += [get_op(op)(**geno)]

        self.model = nn.ModuleList(model)

        self.model.to(self.device)
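Since self.model is a plain ModuleList built from the genotype sequence, the forward pass (not shown) presumably applies the modules in order. A minimal standalone sketch of that pattern:

import torch
import torch.nn as nn

model = nn.ModuleList([nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)])
out = torch.randn(1, 3, 32, 32)
for module in model:  # ModuleList is not callable itself; iterate and apply in order
    out = module(out)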
Example 11
    def __init__(self, search_space, device, op_cls, num_emb, num_hid,
                 share_primitive_w, batchnorm_step, batchnorm_out, **kwargs):
        super(RNNDiffSharedFromCell, self).__init__()

        self.num_emb = num_emb
        self.num_hid = num_hid
        self.batchnorm_step = batchnorm_step
        self.batchnorm_out = batchnorm_out
        self._steps = search_space.num_steps
        self._num_init = search_space.num_init_nodes
        self._primitives = search_space.shared_primitives

        # the first step, convert input x and previous hidden
        self.w_prev = nn.Linear(num_emb + num_hid, 2 * num_hid, bias=False)
        self.w_prev.weight.data.uniform_(-INIT_RANGE, INIT_RANGE)

        if self.batchnorm_step:
            # batchnorm after every step (just as in darts's implementation)
            # self.bn_steps = nn.ModuleList([nn.BatchNorm1d(num_hid, affine=False)
            #                                for _ in range(self._steps+1)])

            ## darts: (but seems odd...)
            self.bn_step = nn.BatchNorm1d(num_hid, affine=False)
            self.bn_steps = [self.bn_step] * (self._steps + 1)

        if self.batchnorm_out:
            # the out bn
            self.bn_out = nn.BatchNorm1d(num_hid, affine=True)

        self.step_weights = nn.ParameterList([
            nn.Parameter(torch.Tensor(num_hid, 2*num_hid)\
                         .uniform_(-INIT_RANGE, INIT_RANGE))
            for _ in range(self._steps)])

        self.p_ops = nn.ModuleList()
        for primitive in self._primitives:
            op = ops.get_op(primitive)()
            self.p_ops.append(op)
Example 12
    def __init__(
        self,
        primitives,
        in_channels,
        out_channels,
        stride,
        affine,
        partial_channel_proportion=None,
    ):
        super(Layer2MicroDiffEdge, self).__init__()
        # assert "none" not in primitives, "Edge should not have `none` primitive"

        self.primitives = primitives
        self.stride = stride
        self.partial_channel_proportion = partial_channel_proportion

        if self.partial_channel_proportion is not None:
            expect(
                in_channels % self.partial_channel_proportion == 0,
                "#channels must be divisible by partial_channel_proportion",
                ConfigException,
            )
            expect(
                out_channels % self.partial_channel_proportion == 0,
                "#channels must be divisible by partial_channel_proportion",
                ConfigException,
            )
            in_channels = in_channels // self.partial_channel_proportion
            out_channels = out_channels // self.partial_channel_proportion

        self.p_ops = nn.ModuleList()
        for primitive in self.primitives:
            op = ops.get_op(primitive)(in_channels, out_channels, stride,
                                       False)
            if "pool" in primitive:
                op = nn.Sequential(op,
                                   nn.BatchNorm2d(out_channels, affine=False))

            self.p_ops.append(op)
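Layer2MicroDiffEdge only divides the channel counts in its constructor; at forward time the partial-channel proportion is typically used PC-DARTS style: only 1/k of the input channels go through the candidate ops while the rest bypass them, and the two parts are concatenated again. This is an assumption about the forward pass, which is not part of the snippet; mixed_out below is a placeholder for the weighted mixture shown in the earlier sketch.

import torch

k = 4                                 # partial_channel_proportion
x = torch.randn(2, 16, 8, 8)
dim_active = x.shape[1] // k
x_active, x_bypass = x[:, :dim_active], x[:, dim_active:]
mixed_out = x_active                  # placeholder for the mixed-op output on the active part
out = torch.cat([mixed_out, x_bypass], dim=1)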
Example 13
    def __new__(cls, prim_type, spatial_size, C, C_out, stride, affine,
                **kwargs):
        position_params = ["C", "C_out", "stride", "affine"]
        prim_constructor = get_op(prim_type)
        prim_sig = signature(prim_constructor)
        params = prim_sig.parameters
        for name, param in params.items():
            if param.default != inspect._empty:
                if name in position_params:
                    continue
                if kwargs.get(name) is None:
                    kwargs[name] = param.default
            else:
                assert name in position_params or name in kwargs, \
                    "{} is a non-default parameter which should be provided explicitly.".format(
                    name)

        kwargs = {k: v for k, v in kwargs.items() if v is not None}

        assert set(params.keys()) == set(
            position_params + list(kwargs.keys())),\
            ("The passed parameters are different from the formal parameter list of primitive "
             "type `{}`, expected {}, got {}").format(
                 prim_type,
                 str(params.keys()),
                 str(position_params + list(kwargs.keys()))
             )

        kwargs = tuple(sorted([(k, v) for k, v in kwargs.items()]))
        return super(Prim, cls).__new__(
            cls,
            prim_type,
            int(spatial_size),
            int(C),
            int(C_out),
            int(stride),
            affine,
            kwargs,
        )
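Prim is evidently a namedtuple subclass whose __new__ fills unspecified kwargs with the defaults of the primitive's constructor and freezes them into a hashable tuple. The self-contained sketch below reproduces that behavior with a stub registry and a hypothetical conv_3x3 primitive (both are illustrations, not aw_nas code), omitting the positional-parameter bookkeeping of the original:

from collections import namedtuple
from inspect import signature

_REGISTRY = {}

def register_op(name):
    def _deco(fn):
        _REGISTRY[name] = fn
        return fn
    return _deco

def get_op(name):
    return _REGISTRY[name]

@register_op("conv_3x3")  # hypothetical primitive for illustration
def conv_3x3(C, C_out, stride, affine, kernel_size=3):
    return ("conv", C, C_out, stride, affine, kernel_size)

_PrimBase = namedtuple(
    "Prim", ["prim_type", "spatial_size", "C", "C_out", "stride", "affine", "kwargs"])

class Prim(_PrimBase):
    def __new__(cls, prim_type, spatial_size, C, C_out, stride, affine, **kwargs):
        # fill missing kwargs with the defaults from the primitive's signature
        params = signature(get_op(prim_type)).parameters
        for name, param in params.items():
            if param.default is not param.empty and kwargs.get(name) is None:
                kwargs[name] = param.default
        kwargs = tuple(sorted(kwargs.items()))  # frozen, hashable representation
        return super().__new__(cls, prim_type, int(spatial_size), int(C),
                               int(C_out), int(stride), affine, kwargs)

p = Prim("conv_3x3", spatial_size=32, C=16, C_out=32, stride=1, affine=True)
print(p.kwargs)  # (('kernel_size', 3),)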
Example 14
    def __init__(self,
                 search_space,
                 device,
                 rollout_type,
                 cell_cls,
                 op_cls,
                 gpus=tuple(),
                 num_classes=10,
                 init_channels=16,
                 stem_multiplier=3,
                 max_grad_norm=5.0,
                 dropout_rate=0.1,
                 use_stem="conv_bn_3x3",
                 stem_stride=1,
                 stem_affine=True,
                 preprocess_op_type=None,
                 cell_use_preprocess=True,
                 cell_group_kwargs=None,
                 cell_use_shortcut=False,
                 cell_shortcut_op_type="skip_connect"):
        super(SharedNet, self).__init__(search_space, device, rollout_type)
        nn.Module.__init__(self)

        # optionally data parallelism in SharedNet
        self.gpus = gpus

        self.num_classes = num_classes
        # init channel number of the first cell layers,
        # x2 after every reduce cell
        self.init_channels = init_channels
        # channels of stem conv / init_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem
        # possible cell group kwargs
        self.cell_group_kwargs = cell_group_kwargs
        # possible inter-cell shortcut
        self.cell_use_shortcut = cell_use_shortcut
        self.cell_shortcut_op_type = cell_shortcut_op_type

        # training
        self.max_grad_norm = max_grad_norm
        self.dropout_rate = dropout_rate

        # search space configs
        self._num_init = self.search_space.num_init_nodes
        self._cell_layout = self.search_space.cell_layout
        self._reduce_cgs = self.search_space.reduce_cell_groups
        self._num_layers = self.search_space.num_layers

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1
            for i_layer in range(self._num_layers)
        ]

        for i_layer, stride in enumerate(strides):
            if stride > 1:
                num_channels *= stride
            if cell_group_kwargs is not None:
                # support passing in different kwargs when instantiating
                # cell class for different cell groups
                kwargs = {
                    k: v
                    for k, v in cell_group_kwargs[
                        self._cell_layout[i_layer]].items()
                }
            else:
                kwargs = {}
            # A patch: one can specify input/output channels by hand in the configuration,
            # instead of relying on the default
            # "whenever stride==2, channels x2 and mapping with preprocess operations" assumption
            _num_channels = num_channels if "C_in" not in kwargs \
                            else kwargs.pop("C_in")
            _num_out_channels = num_channels if "C_out" not in kwargs \
                                else kwargs.pop("C_out")
            cell = cell_cls(op_cls,
                            self.search_space,
                            layer_index=i_layer,
                            num_channels=_num_channels,
                            num_out_channels=_num_out_channels,
                            prev_num_channels=tuple(prev_num_channels),
                            stride=stride,
                            prev_strides=init_strides + strides[:i_layer],
                            use_preprocess=cell_use_preprocess,
                            preprocess_op_type=preprocess_op_type,
                            use_shortcut=cell_use_shortcut,
                            shortcut_op_type=cell_shortcut_op_type,
                            **kwargs)
            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]
            self.cells.append(cell)

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(prev_num_channel, self.num_classes)

        self.to(self.device)
Example 15
    def __init__(
        self,
        search_space,  # layer2
        device,
        genotypes,  # layer2
        micro_model_type="micro-dense-model",
        micro_model_cfg={},
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        auxiliary_head=False,
        auxiliary_cfg=None,
        schedule_cfg=None,
    ):
        super(MacroStagewiseFinalModel, self).__init__(schedule_cfg)

        self.macro_ss = search_space.macro_search_space
        self.micro_ss = search_space.micro_search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes_str = genotypes
        self.macro_g, self.micro_g = genotype_from_str(genotypes, search_space)

        # micro model (cell) class
        micro_model_cls = FinalModel.get_class_(micro_model_type)  # cell type

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.stem_multiplier = stem_multiplier
        self.stem_stride = stem_stride
        self.stem_affine = stem_affine
        self.use_stem = use_stem

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate

        self.auxiliary_head = auxiliary_head

        self.overall_adj = self.macro_ss.parse_overall_adj(self.macro_g)
        self.layer_widths = [float(w) for w in self.macro_g.width.split(",")]

        self.micro_model_cfg = micro_model_cfg
        if "postprocess" in self.micro_model_cfg.keys():
            self.cell_use_postprocess = self.micro_model_cfg["postprocess"]
        else:
            self.cell_use_postprocess = False

        # sort channels out
        assert self.stem_multiplier == 1, "Cannot handle stem_multiplier != 1 now"
        self.input_channel_list = [self.init_channels]
        for i in range(1, self.macro_ss.num_layers):
            self.input_channel_list.append(
                self.input_channel_list[i - 1] * 2 if self._is_reduce(i - 1)
                else self.input_channel_list[i - 1])
        for i in range(self.macro_ss.num_layers):
            self.input_channel_list[i] = int(
                self.input_channel_list[i] * self.layer_widths[i]
                if not self._is_reduce(i) else self.input_channel_list[i] *
                self.layer_widths[i - 1])

        self.output_channel_list = self.input_channel_list[1:] + [
            self.input_channel_list[-1]
        ]

        # construct cells
        if not self.use_stem:
            raise NotImplementedError
            c_stem = 3
        elif isinstance(self.use_stem, (list, tuple)):
            raise NotImplementedError
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stem = nn.Sequential(self.stems)
        else:
            self.stem = ops.get_op(self.use_stem)(3,
                                                  self.input_channel_list[0],
                                                  stride=stem_stride,
                                                  affine=stem_affine)

        self.extra_stem = ops.get_op("nor_conv_1x1")(
            self.input_channel_list[0],
            self.input_channel_list[0] * self.micro_ss.num_steps,
            stride=1,
            affine=True,
        )

        # For sink-connect, don't init all cells, just init connected cells

        connected_cells = []
        for cell_idx in range(1, self.macro_ss.num_layers + 2):
            if len(self.overall_adj[cell_idx].nonzero()[0]) > 0:
                connected_cells.append(self.overall_adj[cell_idx].nonzero()[0])
        # -1 to make the 1st element 0
        self.connected_cells = np.concatenate(connected_cells)[1:] - 1
        """
        ininitialize cells, only connected cells are initialized
        also use `use_next_stage_width` to handle the disalignment of width due to width search
        """
        self.cells = nn.ModuleList()
        self.micro_arch_list = self.micro_ss.rollout_from_genotype(
            self.micro_g).arch
        for i_layer in range(self.macro_ss.num_layers):
            stride = 2 if self._is_reduce(i_layer) else 1
            connected_is_reduce = [
                self._is_reduce(i) for i in self.connected_cells
            ]
            # the layer-idx to use next stage's width: the last cell before the reduction cell in each stage
            use_next_stage_width_layer_idx = self.connected_cells[
                np.argwhere(np.array(connected_is_reduce)).reshape(-1) - 1]
            reduction_layer_idx = self.connected_cells[np.argwhere(
                np.array(connected_is_reduce)).reshape(-1)]  # indices of the reduction cells among the connected cells
            if not self.cell_use_postprocess:
                next_stage_widths = (np.array(
                    self.output_channel_list)[self.macro_ss.stages_begin[1:]]
                                     // 2)  # preprocess, so no //2
            else:
                next_stage_widths = (
                    np.array(self.output_channel_list)[
                        self.macro_ss.stages_begin[1:]] // 2
                )  # the width to use for `use_next_stage_width`; the reduction cell has expansion 2, so //2
            use_next_stage_width = (
                next_stage_widths[np.argwhere(
                    use_next_stage_width_layer_idx == i_layer).reshape(-1)] if
                np.argwhere(use_next_stage_width_layer_idx == i_layer).size > 0
                else None)
            input_channel_list_n = np.array(self.input_channel_list)
            input_channel_list_n[
                reduction_layer_idx] = next_stage_widths  # input of the reduction should be half of the next stage's width

            cg_idx = self.macro_ss.cell_layout[i_layer]
            if i_layer not in self.connected_cells:
                continue
            # construct micro cell
            cell = micro_model_cls(
                self.micro_ss,
                self.micro_arch_list[cg_idx],
                num_input_channels=int(
                    input_channel_list_n[i_layer]
                ),  # TODO: input_channel_list is of type: np.int64
                num_out_channels=self.output_channel_list[i_layer],
                stride=stride,
                use_next_stage_width=use_next_stage_width,
                is_last_cell=True
                if i_layer == self.connected_cells[-1] else False,
                is_first_cell=True
                if i_layer == self.connected_cells[0] else False,
                skip_cell=False,
                **micro_model_cfg)
            # assume non-reduce cell does not change channel number

            self.cells.append(cell)
            # add auxiliary head

        # connected_cells has 1 more element [0] than the self.cells
        if self.auxiliary_head:
            self.where_aux_head = self.connected_cells[(2 * len(self.cells)) //
                                                       3]
            extra_expansion_for_aux = (
                1 if self.cell_use_postprocess else self.micro_ss.num_steps
            )  # if use preprocess, aux head's input ch num should change accordingly
            # aux head is connected to last cell's output
            if auxiliary_head == "imagenet":
                self.auxiliary_net = AuxiliaryHeadImageNet(
                    input_channel_list_n[self.where_aux_head] *
                    extra_expansion_for_aux, num_classes,
                    **(auxiliary_cfg or {}))
            else:
                self.auxiliary_net = AuxiliaryHead(
                    input_channel_list_n[self.where_aux_head] *
                    extra_expansion_for_aux, num_classes,
                    **(auxiliary_cfg or {}))

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()

        if not self.cell_use_postprocess:
            self.classifier = nn.Linear(
                self.output_channel_list[-1] * self.micro_ss.num_steps,
                self.num_classes)
        else:
            self.classifier = nn.Linear(self.output_channel_list[-1],
                                        self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self._set_hook()
Example 16
    def __init__(
        self,
        search_space,
        device,
        rollout_type="dense_rob",
        gpus=tuple(),
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        max_grad_norm=5.0,
        drop_rate=0.2,
        drop_out_rate=0.1,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        candidate_eval_no_grad=False,  # need grad in eval to craft adv examples
        calib_bn_batch=0,
        calib_bn_num=0
    ):
        super(RobSharedNet, self).__init__(search_space, device, rollout_type)

        nn.Module.__init__(self)

        cell_cls = RobSharedCell
        op_cls = RobSharedOp
        # optionally data parallelism in SharedNet
        self.gpus = gpus

        self.search_space = search_space
        self.num_classes = num_classes
        self.device = device
        self.drop_rate = drop_rate
        self.drop_out_rate = drop_out_rate
        self.init_channels = init_channels

        # channels of stem conv / init_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem

        # training
        self.max_grad_norm = max_grad_norm

        # search space configs
        self._ops_choices = self.search_space.primitives
        self._num_layers = self.search_space.num_layers

        self._num_init = self.search_space.num_init_nodes
        self._num_layers = self.search_space.num_layers
        self._cell_layout = self.search_space.cell_layout

        self.calib_bn_batch = calib_bn_batch
        self.calib_bn_num = calib_bn_num
        if self.calib_bn_num > 0 and self.calib_bn_batch > 0:
            self.logger.warn("`calib_bn_num` and `calib_bn_batch` set simultaneously, "
                             "will use `calib_bn_num` only")

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(
                        c_in, c_stem, stride=stem_stride, affine=stem_affine
                    )
                )
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(
                3, c_stem, stride=stem_stride, affine=stem_affine
            )
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1 for i_layer in range(self._num_layers)
        ]

        for i_layer, stride in enumerate(strides):
            if stride > 1:
                num_channels *= stride
            num_out_channels = num_channels
            cell = cell_cls(
                op_cls,
                self.search_space,
                num_input_channels=prev_num_channels,
                num_out_channels=num_out_channels,
                stride=stride,
                prev_strides=init_strides + strides[:i_layer],
            )

            self.cells.append(cell)
            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.drop_rate and self.drop_rate > 0:
            self.dropout = nn.Dropout(p=self.drop_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(prev_num_channels[-1], self.num_classes)
        self.to(self.device)

        self.candidate_eval_no_grad = candidate_eval_no_grad
        self.assembled = 0
        self.candidate_map = weakref.WeakValueDictionary()

        self.set_hook()
        self._flops_calculated = False
        self.total_flops = 0
Example 17
    def __init__(
        self,
        search_space,
        device,
        genotypes,
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        schedule_cfg=None,
    ):
        super(DenseRobFinalModel, self).__init__(schedule_cfg)

        self.search_space = search_space
        self.device = device
        assert isinstance(genotypes, str)
        genotypes = genotype_from_str(genotypes, self.search_space)
        self.arch_list = self.search_space.rollout_from_genotype(
            genotypes).arch

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate

        # search space configs
        self._num_init = self.search_space.num_init_nodes
        self._num_layers = self.search_space.num_layers

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1
            for i_layer in range(self._num_layers)
        ]

        for i_layer, stride in enumerate(strides):
            if stride > 1:
                num_channels *= stride
            num_out_channels = num_channels
            kwargs = {}
            cg_idx = self.search_space.cell_layout[i_layer]

            cell = DenseRobCell(
                self.search_space,
                self.arch_list[cg_idx],
                # num_channels=num_channels,
                num_input_channels=prev_num_channels,
                num_out_channels=num_out_channels,
                # prev_num_channels=tuple(prev_num_channels),
                prev_strides=init_strides + strides[:i_layer],
                stride=stride,
                **kwargs)

            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]
            self.cells.append(cell)

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(prev_num_channels[-1], self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self.set_hook()
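set_hook is not shown; together with total_flops and _flops_calculated it suggests forward hooks that accumulate FLOPs during the first forward pass. A standalone sketch of that mechanism (an assumption about the hook, counting only Conv2d/Linear multiply-accumulates):

import torch
import torch.nn as nn

def register_flops_hooks(model):
    model.total_flops = 0.0

    def _hook(module, inputs, output):
        if isinstance(module, nn.Conv2d):
            kh, kw = module.kernel_size
            model.total_flops += output.numel() * module.in_channels // module.groups * kh * kw
        elif isinstance(module, nn.Linear):
            model.total_flops += output.shape[0] * module.in_features * module.out_features

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            m.register_forward_hook(_hook)

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(),
                    nn.Linear(8 * 32 * 32, 10))
register_flops_hooks(net)
net(torch.randn(1, 3, 32, 32))
print(net.total_flops)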
Example 18
    def __init__(self, search_space, device, op_cls, num_emb, num_hid,
                 share_from_weights, batchnorm_step, batchnorm_edge,
                 batchnorm_out, genotypes, **kwargs):
        super(RNNGenotypeCell, self).__init__()
        self.genotypes = genotypes

        self.search_space = search_space

        self.num_emb = num_emb
        self.num_hid = num_hid
        self.batchnorm_step = batchnorm_step
        self.batchnorm_edge = batchnorm_edge
        self.batchnorm_out = batchnorm_out
        self.share_from_w = share_from_weights
        self._steps = search_space.num_steps
        self._num_init = search_space.num_init_nodes

        # the first step, convert input x and previous hidden
        self.w_prev = nn.Linear(num_emb + num_hid, 2 * num_hid, bias=False)
        self.w_prev.weight.data.uniform_(-INIT_RANGE, INIT_RANGE)

        if self.batchnorm_edge:
            # batchnorm on each edge/connection
            # when `num_node_inputs==1`, there are `step + 1` edges
            # the first bn
            self.bn_prev = nn.BatchNorm1d(num_emb + num_hid, affine=True)
            # other bn
            self.bn_edges = nn.ModuleList([
                nn.BatchNorm1d(num_emb + num_hid, affine=True)
                for _ in range(len(self.genotypes[0]))
            ])

        if self.batchnorm_step:
            # batchnorm after every step (as in darts's implementation)
            self.bn_steps = nn.ModuleList([
                nn.BatchNorm1d(num_hid, affine=False)
                for _ in range(self._steps + 1)
            ])

        if self.batchnorm_out:
            # the out bn
            self.bn_out = nn.BatchNorm1d(num_hid, affine=True)

        if self.share_from_w:
            # actually, as `num_node_inputs==1`, only one from-node is used at each step
            # `share_from_w==True/False` are equivalent in final training...
            self.step_weights = nn.ModuleList([
                nn.Linear(num_hid, 2 * num_hid, bias=False)
                for _ in range(self._steps)
            ])
            for mod in self.step_weights:
                mod.weight.data.uniform_(-INIT_RANGE, INIT_RANGE)

        # instantiate ops on edges
        self.Ws = nn.ModuleList()
        self.ops = nn.ModuleList()
        genotype_, _ = self.genotypes

        for op_type, _, _ in genotype_:
            # edge weights
            op = ops.get_op(op_type)()
            self.ops.append(op)
            if not self.share_from_w:
                W = nn.Linear(self.num_hid, 2 * self.num_hid, bias=False)
                W.weight.data.uniform_(-INIT_RANGE, INIT_RANGE)
                self.Ws.append(W)
Example 19
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        affine,
        primitives,
        num_steps,
        num_init_nodes,
        output_op="concat",
        width_choice=[1.0],
        postprocess=False,
        process_op="conv_1x1",
        cell_shortcut=False,
        cell_shortcut_op="skip_connect",
        partial_channel_proportion=None,
        cell_idx=None,
        use_next_stage_width=None,
    ):
        super(Layer2MicroDiffCell, self).__init__()

        self.out_channels = out_channels
        self.stride = stride

        self.primitives = primitives
        self.num_init_nodes = num_init_nodes
        self.num_steps = num_steps
        self.num_nodes = num_steps + num_init_nodes
        self.output_op = output_op

        self.cell_idx = cell_idx
        self.use_next_stage_width = use_next_stage_width

        self.postprocess = postprocess

        self.width_choice = width_choice
        assert 1.0 in self.width_choice, "Must have a width choice with 100% channels"

        self.register_buffer(
            "channel_masks",
            torch.zeros(len(self.width_choice), self.out_channels))
        for i, w in enumerate(self.width_choice):
            self.channel_masks[i][:int(w * self.out_channels)] = 1.0

        self.partial_channel_proportion = partial_channel_proportion
        assert partial_channel_proportion is None  # partial channel is currently not supported

        # it's easier to calc edge indices with a longer ModuleList and some None
        self.edges = nn.ModuleList()
        for j in range(self.num_nodes):
            for i in range(self.num_nodes):
                if j > i:
                    if i < self.num_init_nodes:
                        self.edges.append(
                            Layer2MicroDiffEdge(primitives, in_channels,
                                                out_channels, stride, affine))
                    else:
                        self.edges.append(
                            Layer2MicroDiffEdge(primitives, out_channels,
                                                out_channels, 1, affine))
                else:
                    self.edges.append(None)

        if cell_shortcut and cell_shortcut_op != "none":
            if not self.postprocess:
                self.shortcut = ops.get_op(cell_shortcut_op)(
                    in_channels * self.num_steps,
                    out_channels * self.num_steps,
                    stride,
                    affine,
                )
            else:
                self.shortcut = ops.get_op(cell_shortcut_op)(
                    in_channels,
                    out_channels,
                    stride,
                    affine,
                )
        else:
            self.shortcut = None

        if self.output_op == "concat":
            """
            no matter post/preprocess [4c,c]
            however with reduction cell, out_channel = 2*in_c,
            so when using preprocess, should use in_channels
            """
            if not self.postprocess:
                self.process = ops.get_op(process_op)(in_channels * num_steps,
                                                      in_channels,
                                                      stride=1,
                                                      affine=False)
            else:
                self.process = ops.get_op(process_op)(out_channels * num_steps,
                                                      out_channels,
                                                      stride=1,
                                                      affine=False)

        self.total_flops = 0.0
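The channel_masks buffer implies that width choices are applied at forward time by zeroing out the trailing channels of an NCHW feature map. A short standalone sketch of that masking (an assumption; the cell's forward is not part of the snippet):

import torch

out_channels, width_choice = 16, [0.25, 0.5, 1.0]
channel_masks = torch.zeros(len(width_choice), out_channels)
for i, w in enumerate(width_choice):
    channel_masks[i, :int(w * out_channels)] = 1.0

feat = torch.randn(2, out_channels, 8, 8)
masked = feat * channel_masks[1].view(1, -1, 1, 1)  # keep the first 50% of the channels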
Example 20
    def __init__(
        self,
        search_space,
        arch,
        num_input_channels,
        num_out_channels,
        stride,
        postprocess=False,  # default use preprocess
        process_op_type="nor_conv_1x1",
        use_shortcut=True,
        shortcut_op_type="skip_connect",
        # applied on every cell at the end of the stage, before the reduction cell, to ensure x2 ch in reduction
        use_next_stage_width=None,
        is_last_cell=False,
        is_first_cell=False,
        skip_cell=False,
        schedule_cfg=None,
    ):
        super(MicroDenseCell, self).__init__(schedule_cfg)

        self.search_space = search_space
        self.arch = arch
        self.stride = stride
        self.postprocess = postprocess
        self.process_op_type = process_op_type

        self.use_shortcut = use_shortcut
        self.shortcut_op_type = shortcut_op_type

        self.num_steps = self.search_space.num_steps

        self._num_nodes = self.search_space._num_nodes
        self._primitives = self.search_space.primitives
        self._num_init_nodes = self.search_space.num_init_nodes

        self.is_last_cell = is_last_cell
        self.is_first_cell = is_first_cell
        self.skip_cell = skip_cell

        if use_next_stage_width is not None:
            self.use_next_stage_width = use_next_stage_width.item()
        else:
            self.use_next_stage_width = use_next_stage_width

        if self.use_next_stage_width:
            # when use_next_stage_width, should apply to normal_cell, in_c == out_c
            assert num_input_channels == num_out_channels

        #  self.num_input_channels = num_input_channels if self.use_next_stage_width is None else self.use_next_stage_width
        #  self.num_out_channels = num_out_channels if self.use_next_stage_width is None else self.use_next_stage_width
        self.num_input_channels = num_input_channels
        self.num_out_channels = num_out_channels

        if self.use_shortcut:
            """
            no 'use-next-stage-width' is applied in to the cell-wise shortcut,
            since the 'use-next-stage-width' only happens in last cell before reduction,
            the shortcut is usually plain shortcut and could not handle ch disalignment

            when using preprocess, the shortcut is of 4C width;
            when using postprocess, the shortcut is of C witdh;
            """
            if not self.postprocess:
                self.shortcut_reduction_op = ops.get_op(self.shortcut_op_type)(
                    C=num_input_channels * self.num_steps,
                    C_out=num_out_channels * self.num_steps,
                    stride=self.stride,
                    affine=True,
                )
            else:
                self.shortcut_reduction_op = ops.get_op(self.shortcut_op_type)(
                    C=num_input_channels,
                    C_out=num_out_channels,
                    stride=self.stride,
                    affine=True,
                )

        self.edges = _defaultdict_3()
        self.edge_mod = torch.nn.Module()  # a stub wrapping module of all the edges
        for from_ in range(self._num_nodes):
            for to_ in range(max(self._num_init_nodes, from_ + 1),
                             self._num_nodes):
                num_input_channels = (self.num_input_channels
                                      if from_ < self._num_init_nodes else
                                      self.num_out_channels)
                stride = self.stride if from_ < self._num_init_nodes else 1
                for op_ind in np.where(self.arch[to_, from_])[0]:
                    op_type = self._primitives[op_ind]
                    self.edges[from_][to_][op_type] = ops.get_op(op_type)(
                        # when applying the preprocess and cell `use-next-stage-width` the op width should also align with next stage width
                        self.use_next_stage_width if
                        (self.use_next_stage_width is not None
                         and not self.postprocess) else num_input_channels,
                        self.use_next_stage_width if
                        (self.use_next_stage_width is not None
                         and not self.postprocess) else self.num_out_channels,
                        stride=stride,
                        affine=False,
                    )
                    self.edge_mod.add_module(
                        "f_{}_t_{}_{}".format(from_, to_, op_type),
                        self.edges[from_][to_][op_type],
                    )
        self._edge_name_pattern = re.compile(
            "f_([0-9]+)_t_([0-9]+)_([a-z0-9_-]+)")

        self.use_concat = self.search_space.concat_op == "concat"
        if self.use_concat:
            if not self.postprocess:
                # currently, map the concatenated output to num_out_channels
                self.process_op = ops.get_op(self.process_op_type)(
                    C=self.num_input_channels * self.search_space.num_steps,
                    # change the process op's output width to align with the next stage's width
                    # ensuring the reduction cell meets in_c*2 == out_c
                    C_out=self.use_next_stage_width
                    if self.use_next_stage_width is not None else
                    self.num_input_channels,
                    stride=1,
                    affine=True,
                )
            else:
                self.process_op = ops.get_op(self.process_op_type)(
                    C=self.num_out_channels * self.search_space.num_steps,
                    # change the process op's output width to align with the next stage's width
                    # ensuring the reduction cell meets in_c*2 == out_c
                    C_out=self.use_next_stage_width
                    if self.use_next_stage_width is not None else
                    self.num_out_channels,
                    stride=1,
                    affine=True,
                )
Example 21
    def __init__(self, op_cls, search_space, layer_index, num_channels,
                 num_out_channels, prev_num_channels, stride, prev_strides,
                 use_preprocess, preprocess_op_type, use_shortcut,
                 shortcut_op_type, **op_kwargs):
        super(SharedCell, self).__init__()
        self.search_space = search_space
        self.stride = stride
        self.is_reduce = stride != 1
        self.num_channels = num_channels
        self.num_out_channels = num_out_channels
        self.layer_index = layer_index
        self.use_preprocess = use_preprocess
        self.preprocess_op_type = preprocess_op_type
        self.use_shortcut = use_shortcut
        self.shortcut_op_type = shortcut_op_type
        self.op_kwargs = op_kwargs

        self._steps = self.search_space.get_layer_num_steps(layer_index)
        self._num_init = self.search_space.num_init_nodes
        if not self.search_space.cellwise_primitives:
            # the same set of primitives for different cg group
            self._primitives = self.search_space.shared_primitives
        else:
            # different set of primitives for different cg group
            self._primitives = \
                self.search_space.cell_shared_primitives[self.search_space.cell_layout[layer_index]]

        # initialize self.concat_op, self._out_multiplier (only used for discrete super net)
        self.concat_op = ops.get_concat_op(self.search_space.concat_op)
        if not self.concat_op.is_elementwise:
            expect(
                not self.search_space.loose_end,
                "For the shared-weights weights manager, a non-elementwise concat op does not "
                "support the loose-end search space")
            self._out_multipler = self._steps if not self.search_space.concat_nodes \
                                  else len(self.search_space.concat_nodes)
        else:
            # elementwise concat op. e.g. sum, mean
            self._out_multipler = 1

        self.preprocess_ops = nn.ModuleList()
        prev_strides = list(np.cumprod(list(reversed(prev_strides))))
        prev_strides.insert(0, 1)
        prev_strides = reversed(prev_strides[:len(prev_num_channels)])
        for prev_c, prev_s in zip(prev_num_channels, prev_strides):
            if not self.use_preprocess:
                # stride/channel not handled!
                self.preprocess_ops.append(ops.Identity())
                continue
            if self.preprocess_op_type is not None:
                # specify another preprocess op
                preprocess = ops.get_op(self.preprocess_op_type)(
                    C=prev_c,
                    C_out=num_channels,
                    stride=int(prev_s),
                    affine=False)
            else:
                if prev_s > 1:
                    # need skip connection, and is not the connection from the input image
                    preprocess = ops.FactorizedReduce(C_in=prev_c,
                                                      C_out=num_channels,
                                                      stride=prev_s,
                                                      affine=False)
                else:  # prev_c == _steps * num_channels or inputs
                    preprocess = ops.ReLUConvBN(C_in=prev_c,
                                                C_out=num_channels,
                                                kernel_size=1,
                                                stride=1,
                                                padding=0,
                                                affine=False)
            self.preprocess_ops.append(preprocess)
        assert len(self.preprocess_ops) == self._num_init

        if self.use_shortcut:
            self.shortcut_reduction_op = ops.get_op(self.shortcut_op_type)(
                C=prev_num_channels[-1],
                C_out=self.num_out_channel(),
                stride=self.stride,
                affine=True)

        self.edges = defaultdict(dict)
        self.edge_mod = torch.nn.Module()  # a stub wrapping module of all the edges
        for i_step in range(self._steps):
            to_ = i_step + self._num_init
            for from_ in range(to_):
                self.edges[from_][to_] = op_cls(
                    self.num_channels,
                    self.num_out_channels,
                    stride=self.stride if from_ < self._num_init else 1,
                    primitives=self._primitives,
                    **op_kwargs)
                self.edge_mod.add_module("f_{}_t_{}".format(from_, to_),
                                         self.edges[from_][to_])

        self._edge_name_pattern = re.compile("f_([0-9]+)_t_([0-9]+)")
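
The `edges` mapping built at the end of this constructor is a plain `defaultdict`, so by itself PyTorch would not track the per-edge ops; that is why each op is also registered on the stub `edge_mod` module, and why the `f_{from}_t_{to}` regex exists (to recover the indices from the registered names later). Below is a minimal, self-contained sketch of the same registration pattern, with a toy `nn.Conv2d` standing in for `ops.get_op(...)` (the class name and channel numbers are made up for illustration):

import re
from collections import defaultdict

import torch.nn as nn


class TinyCell(nn.Module):
    def __init__(self, num_nodes=4, channels=8):
        super().__init__()
        self.edges = defaultdict(dict)   # plain dict: invisible to PyTorch's module tracking
        self.edge_mod = nn.Module()      # stub module that actually owns the edge ops
        for to_ in range(1, num_nodes):
            for from_ in range(to_):
                op = nn.Conv2d(channels, channels, 1)   # toy op instead of ops.get_op(...)
                self.edges[from_][to_] = op
                self.edge_mod.add_module("f_{}_t_{}".format(from_, to_), op)
        self._edge_name_pattern = re.compile("f_([0-9]+)_t_([0-9]+)")

    def named_edges(self):
        # recover the (from, to) indices from the registered submodule names
        for name, mod in self.edge_mod.named_children():
            from_, to_ = map(int, self._edge_name_pattern.match(name).groups())
            yield (from_, to_), mod


cell = TinyCell()
print(len(list(cell.parameters())))            # 12: the edge convs are tracked via edge_mod
print(sorted(k for k, _ in cell.named_edges()))
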
Example no. 22
    def __init__(
            self,
            search_space,  # type: Layer2SearchSpace
            device,
            rollout_type="layer2",
            init_channels=16,
            # classifier
            num_classes=10,
            dropout_rate=0.0,
            # stem
            use_stem="conv_bn_3x3",
            stem_stride=1,
            stem_affine=True,
            stem_multiplier=1,
            max_grad_norm=5.0,
            # candidate
            candidate_eval_no_grad=True,
            # micro-cell cfg
            micro_cell_cfg={},
            # schedule
            schedule_cfg=None,
            gpus=tuple(),
            multiprocess=False,
    ):
        super(Layer2MacroDiffSupernet,
              self).__init__(search_space, device, rollout_type, schedule_cfg)
        nn.Module.__init__(self)

        self.macro_search_space = (search_space.macro_search_space
                                   )  # type: StagewiseMacroSearchSpace
        self.micro_search_space = (search_space.micro_search_space
                                   )  # type: DenseMicroSearchSpace

        self.num_cell_groups = self.macro_search_space.num_cell_groups
        self.cell_layout = self.macro_search_space.cell_layout
        self.reduce_cell_groups = self.macro_search_space.reduce_cell_groups

        self.max_grad_norm = max_grad_norm

        self.candidate_eval_no_grad = candidate_eval_no_grad

        self.micro_cell_cfg = micro_cell_cfg
        if "postprocess" in micro_cell_cfg.keys():
            self.cell_use_postprocess = micro_cell_cfg["postprocess"]
        else:
            self.cell_use_postprocess = (
                False  # defualt use preprocess if not specified in cfg
            )

        self.gpus = gpus
        self.multiprocess = multiprocess

        # make stem
        self.use_stem = use_stem
        if not self.use_stem:
            c_stem = 3
        elif isinstance(self.use_stem, (list, tuple)):
            self.stem = []
            c_stem = stem_multiplier * init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stem.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))

            self.stem = nn.Sequential(*self.stem)
        else:
            c_stem = stem_multiplier * init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)

        # make cells
        self.cells = nn.ModuleList()
        num_channels = init_channels
        prev_num_channels = c_stem
        """when use preprocess, extra stem is applied after normal stem to make 1st cell input 4c"""
        """currently no matter what stem is used, fp-conv 1x1 is inserted to align the width, maybe modify the stem?(maybe more flops)"""
        self.extra_stem = ops.get_op("nor_conv_1x1")(
            prev_num_channels,
            num_channels * self.micro_search_space.num_steps,
            1,
            affine=True,
        )

        for i, cg in enumerate(self.cell_layout):

            use_next_stage_width = np.array(
                self.macro_search_space.stages_end[:-1]) - 1
            # next_stage_begin_idx = np.array(self.macro_search_space,stages_begin[1:])+1
            use_next_stage_width = (
                i + 2 if i in use_next_stage_width else None
            )  # [i] the last cell; [i+1] the reduction cell; [i+2] the 1st cell next stage

            stride = 2 if cg in self.reduce_cell_groups else 1
            num_channels *= stride

            self.cells.append(
                Layer2MicroDiffCell(
                    prev_num_channels,
                    num_channels,
                    stride,
                    affine=True,
                    primitives=self.micro_search_space.primitives,
                    num_steps=self.micro_search_space.num_steps,
                    num_init_nodes=self.micro_search_space.num_init_nodes,
                    output_op=self.micro_search_space.concat_op,
                    width_choice=self.macro_search_space.width_choice,
                    cell_idx=i,
                    use_next_stage_width=use_next_stage_width,
                    **self.micro_cell_cfg,
                ))

            prev_num_channels = num_channels

        # make pooling and classifier
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(
            dropout_rate) if dropout_rate else nn.Identity()
        if self.cell_use_postprocess:
            self.classifier = nn.Linear(prev_num_channels, num_classes)
        else:
            self.classifier = nn.Linear(
                prev_num_channels * self.micro_search_space.num_steps,
                num_classes)

        self.total_flops = 0.0

        self.to(self.device)
        self._parallelize()
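
As a quick sanity check of the channel bookkeeping in this constructor: every reduction cell doubles `num_channels` (`num_channels *= stride`), and when the cells skip their postprocess the classifier sees `num_channels * num_steps` features, because the `num_steps` node outputs are concatenated instead of being projected back to `num_channels`. A toy walk-through under an assumed layout (two reduction cells, `num_steps=4`; all numbers below are illustrative):

init_channels = 16
num_steps = 4                        # illustrative stand-in for micro_search_space.num_steps
cell_layout = [0, 1, 0, 1, 0]        # hypothetical layout; cell group 1 is the reduce group
reduce_cell_groups = (1,)

num_channels = init_channels
for cg in cell_layout:
    stride = 2 if cg in reduce_cell_groups else 1
    num_channels *= stride           # doubled once per reduction cell

print(num_channels)                  # 64
print(num_channels * num_steps)      # 256: classifier in_features when postprocess is off
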
Example no. 23
    def __init__(self,
                 search_space,
                 device,
                 genotypes,
                 num_classes=10,
                 init_channels=36,
                 layer_channels=tuple(),
                 stem_multiplier=3,
                 dropout_rate=0.1,
                 dropout_path_rate=0.2,
                 auxiliary_head=False,
                 auxiliary_cfg=None,
                 use_stem="conv_bn_3x3",
                 stem_stride=1,
                 stem_affine=True,
                 no_fc=False,
                 cell_use_preprocess=True,
                 cell_pool_batchnorm=False,
                 cell_group_kwargs=None,
                 cell_independent_conn=False,
                 cell_preprocess_stride="skip_connect",
                 cell_preprocess_normal="relu_conv_bn_1x1",
                 schedule_cfg=None):
        super(CNNGenotypeModel, self).__init__(schedule_cfg)

        self.search_space = search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes = list(
            genotype_from_str(genotypes, self.search_space)._asdict().values())
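        # the genotype lists all cell groups' connections first, then their concat-node
        # lists; pair each cell group's grouped connections with its concat nodes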
        self.genotypes_grouped = list(
            zip([
                group_and_sort_by_to_node(conns)
                for conns in self.genotypes[:self.search_space.num_cell_groups]
            ], self.genotypes[self.search_space.num_cell_groups:]))
        # self.genotypes_grouped = [group_and_sort_by_to_node(g[1]) for g in self.genotypes\
        #                           if "concat" not in g[0]]

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.layer_channels = layer_channels
        self.stem_multiplier = stem_multiplier
        self.use_stem = use_stem
        self.cell_use_preprocess = cell_use_preprocess
        self.cell_group_kwargs = cell_group_kwargs
        self.cell_independent_conn = cell_independent_conn
        self.no_fc = no_fc

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate
        self.auxiliary_head = auxiliary_head

        # search space configs
        self._num_init = self.search_space.num_init_nodes
        self._cell_layout = self.search_space.cell_layout
        self._reduce_cgs = self.search_space.reduce_cell_groups
        self._num_layers = self.search_space.num_layers
        expect(len(self.genotypes_grouped) == self.search_space.num_cell_groups,
               ("Config genotype cell group number({}) "
                "does not match search_space cell group number({})")\
               .format(len(self.genotypes_grouped), self.search_space.num_cell_groups))

        ## initialize sub modules
        if not self.use_stem:
            c_stem = 3
            init_strides = [1] * self._num_init
        elif isinstance(self.use_stem, (list, tuple)):
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stems = nn.ModuleList(self.stems)
            init_strides = [stem_stride] * self._num_init
        else:
            c_stem = self.stem_multiplier * self.init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

        self.cells = nn.ModuleList()
        num_channels = self.init_channels
        prev_num_channels = [c_stem] * self._num_init
        strides = [
            2 if self._is_reduce(i_layer) else 1
            for i_layer in range(self._num_layers)
        ]
        if self.layer_channels:
            expect(len(self.layer_channels) == len(strides) + 1,
                   ("Config cell channels({}) does not match search_space num layers + 1 ({})"\
                    .format(len(self.layer_channels), self.search_space.num_layers + 1)),
                   ConfigException)
        for i_layer, stride in enumerate(strides):
            if self.layer_channels:
                # input and output channels of every layer is specified
                num_channels = self.layer_channels[i_layer]
                num_out_channels = self.layer_channels[i_layer + 1]
            else:
                if stride > 1:
                    num_channels *= stride
                num_out_channels = num_channels
            if cell_group_kwargs is not None:
                # support passing different kwargs when instantiating the
                # cell class for different cell groups.
                # Input/output channels can be specified by hand in the configuration,
                # instead of relying on the default
                # "whenever stride is 2, double the channels and map with preprocess ops" assumption
                kwargs = {
                    k: v
                    for k, v in cell_group_kwargs[
                        self._cell_layout[i_layer]].items()
                }
                if "C_in" in kwargs:
                    num_channels = kwargs.pop("C_in")
                if "C_out" in kwargs:
                    num_out_channels = kwargs.pop("C_out")
            else:
                kwargs = {}
            cg_idx = self.search_space.cell_layout[i_layer]

            cell = CNNGenotypeCell(self.search_space,
                                   self.genotypes_grouped[cg_idx],
                                   layer_index=i_layer,
                                   num_channels=num_channels,
                                   num_out_channels=num_out_channels,
                                   prev_num_channels=tuple(prev_num_channels),
                                   stride=stride,
                                   prev_strides=init_strides +
                                   strides[:i_layer],
                                   use_preprocess=cell_use_preprocess,
                                   pool_batchnorm=cell_pool_batchnorm,
                                   independent_conn=cell_independent_conn,
                                   preprocess_stride=cell_preprocess_stride,
                                   preprocess_normal=cell_preprocess_normal,
                                   **kwargs)
            # TODO: support specifying concat explicitly
            prev_num_channel = cell.num_out_channel()
            prev_num_channels.append(prev_num_channel)
            prev_num_channels = prev_num_channels[1:]
            self.cells.append(cell)

            if i_layer == (2 * self._num_layers) // 3 and self.auxiliary_head:
                if auxiliary_head == "imagenet":
                    self.auxiliary_net = AuxiliaryHeadImageNet(
                        prev_num_channels[-1], num_classes,
                        **(auxiliary_cfg or {}))
                else:
                    self.auxiliary_net = AuxiliaryHead(prev_num_channels[-1],
                                                       num_classes,
                                                       **(auxiliary_cfg or {}))

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        if self.no_fc:
            self.classifier = ops.Identity()
        else:
            self.classifier = nn.Linear(prev_num_channels[-1],
                                        self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self.set_hook()
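
Each cell takes the outputs of the previous `num_init_nodes` cells as its inputs, so `prev_num_channels` above is maintained as a fixed-length window: the new cell's output channel count is appended and the oldest entry dropped. A minimal sketch of that bookkeeping, with made-up channel numbers:

num_init = 2
prev_num_channels = [48, 48]            # both init inputs start at the stem width
cell_out_channels = [64, 64, 128]       # hypothetical outputs of three consecutive cells

for c_out in cell_out_channels:
    prev_num_channels.append(c_out)             # the newest cell output becomes an input
    prev_num_channels = prev_num_channels[1:]   # drop the oldest; keep a window of num_init
    print(prev_num_channels)
# [48, 64] -> [64, 64] -> [64, 128]
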
Example no. 24
    def __init__(
        self,
        search_space,  # type: Layer2SearchSpace
        device,
        rollout_type="layer2",
        init_channels=16,
        # classifier
        num_classes=10,
        dropout_rate=0.0,
        max_grad_norm=None,
        # stem
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        stem_multiplier=1,
        # candidate
        candidate_eval_no_grad=True,
        # schedule
        schedule_cfg=None,
    ):
        super().__init__(search_space, device, rollout_type, schedule_cfg)
        nn.Module.__init__(self)

        self.macro_search_space = (search_space.macro_search_space
                                   )  # type: StagewiseMacroSearchSpace
        self.micro_search_space = (search_space.micro_search_space
                                   )  # type: DenseMicroSearchSpace

        self.num_cell_groups = self.macro_search_space.num_cell_groups
        self.cell_layout = self.macro_search_space.cell_layout
        self.reduce_cell_groups = self.macro_search_space.reduce_cell_groups

        self.max_grad_norm = max_grad_norm

        self.candidate_eval_no_grad = candidate_eval_no_grad

        # make stem
        self.use_stem = use_stem
        if not self.use_stem:
            c_stem = 3
        elif isinstance(self.use_stem, (list, tuple)):
            self.stem = []
            c_stem = stem_multiplier * init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stem.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))

            self.stem = nn.Sequential(*self.stem)
        else:
            c_stem = stem_multiplier * init_channels
            self.stem = ops.get_op(self.use_stem)(3,
                                                  c_stem,
                                                  stride=stem_stride,
                                                  affine=stem_affine)

        # make cells
        self.cells = nn.ModuleList()
        num_channels = init_channels
        prev_num_channels = c_stem

        for i, cg in enumerate(self.cell_layout):
            stride = 2 if cg in self.reduce_cell_groups else 1
            num_channels *= stride

            self.cells.append(
                Layer2MicroCell(
                    prev_num_channels,
                    num_channels,
                    stride,
                    affine=True,
                    primitives=self.micro_search_space.primitives,
                    num_steps=self.micro_search_space.num_steps,
                    num_init_nodes=self.micro_search_space.num_init_nodes,
                    output_op=self.micro_search_space.concat_op,
                    postprocess_op="conv_1x1",
                    cell_shortcut=True,
                    cell_shortcut_op="skip_connect",
                ))

            prev_num_channels = num_channels

        # make pooling and classifier
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(
            dropout_rate) if dropout_rate else nn.Identity()
        self.classifier = nn.Linear(prev_num_channels, num_classes)

        self.to(self.device)
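
The stem handling above accepts either a single op name or a list of names; in the list case the stems are chained with `nn.Sequential`, only the first one sees the 3 RGB channels, and all of them output `stem_multiplier * init_channels` channels. A standalone sketch of the same branching, with a plain `nn.Conv2d` as a stand-in for `ops.get_op(stem_type)` (the op names here are only illustrative):

import torch
import torch.nn as nn


def make_stem(use_stem, init_channels=16, stem_multiplier=1, stem_stride=1):
    if not use_stem:
        return nn.Identity(), 3                         # raw RGB channels pass through
    c_stem = stem_multiplier * init_channels
    if isinstance(use_stem, (list, tuple)):
        layers = []
        for i, _stem_type in enumerate(use_stem):       # _stem_type would select the real op
            c_in = 3 if i == 0 else c_stem
            layers.append(nn.Conv2d(c_in, c_stem, 3, stride=stem_stride, padding=1))
        return nn.Sequential(*layers), c_stem
    return nn.Conv2d(3, c_stem, 3, stride=stem_stride, padding=1), c_stem


stem, c_stem = make_stem(["conv_bn_3x3", "conv_bn_3x3"])
print(stem(torch.randn(1, 3, 32, 32)).shape, c_stem)    # torch.Size([1, 16, 32, 32]) 16
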
Example no. 25
    def __init__(self, search_space, genotype_grouped, layer_index,
                 num_channels, num_out_channels, prev_num_channels, stride,
                 prev_strides, use_preprocess, pool_batchnorm,
                 independent_conn, preprocess_stride, preprocess_normal,
                 **op_kwargs):
        super(CNNGenotypeCell, self).__init__()
        self.search_space = search_space
        self.conns_grouped, self.concat_nodes = genotype_grouped
        self.stride = stride
        self.is_reduce = stride != 1
        self.num_channels = num_channels
        self.num_out_channels = num_out_channels
        self.layer_index = layer_index
        self.use_preprocess = use_preprocess
        self.pool_batchnorm = pool_batchnorm
        self.independent_conn = independent_conn
        self.op_kwargs = op_kwargs

        self._steps = self.search_space.get_layer_num_steps(layer_index)
        self._num_init = self.search_space.num_init_nodes
        self._primitives = self.search_space.shared_primitives

        # initialize self.concat_op, self._out_multiplier (only used for discrete super net)
        self.concat_op = ops.get_concat_op(self.search_space.concat_op)
        if not self.concat_op.is_elementwise:
            expect(
                not self.search_space.loose_end,
                "For shared weights weights manager, when non-elementwise concat op do not "
                "support loose-end search space")
            self._out_multipler = self._steps if not self.search_space.concat_nodes \
                                  else len(self.search_space.concat_nodes)
        else:
            # elementwise concat op. e.g. sum, mean
            self._out_multipler = 1

        self.preprocess_ops = nn.ModuleList()
        prev_strides = list(np.cumprod(list(reversed(prev_strides))))
        prev_strides.insert(0, 1)
        prev_strides = reversed(prev_strides[:len(prev_num_channels)])
        for prev_c, prev_s in zip(prev_num_channels, prev_strides):
            # print("cin: {}, cout: {}, stride: {}".format(prev_c, num_channels, prev_s))
            if not self.use_preprocess:
                # stride/channel not handled!
                self.preprocess_ops.append(ops.Identity())
                continue
            if prev_s > 1:
                # need skip connection, and is not the connection from the input image
                # ops.FactorizedReduce(C_in=prev_c,
                preprocess = ops.get_op(preprocess_stride)(C=prev_c,
                                                           C_out=num_channels,
                                                           stride=prev_s,
                                                           affine=True)
            else:  # prev_c == _steps * num_channels or inputs
                preprocess = ops.get_op(preprocess_normal)(C=prev_c,
                                                           C_out=num_channels,
                                                           stride=1,
                                                           affine=True)
            self.preprocess_ops.append(preprocess)
        assert len(self.preprocess_ops) == self._num_init

        self.edges = _defaultdict_3()
        self.edge_mod = torch.nn.Module(
        )  # a stub wrapping module of all the edges
        for _, conns in self.conns_grouped:
            for op_type, from_, to_ in conns:
                stride = self.stride if from_ < self._num_init else 1
                op = ops.get_op(op_type)(num_channels, num_out_channels,
                                         stride, True, **op_kwargs)
                if self.pool_batchnorm and "pool" in op_type:
                    op = nn.Sequential(
                        op, nn.BatchNorm2d(num_out_channels, affine=False))
                index = len(self.edges[from_][to_][op_type])
                if index == 0 or self.independent_conn:
                    # this connection has not been established yet,
                    # or an independent op is used for each identical (from, to, op_type)
                    self.edges[from_][to_][op_type][index] = op
                    self.edge_mod.add_module(
                        "f_{}_t_{}-{}-{}".format(from_, to_, op_type, index),
                        op)

        self._edge_name_pattern = re.compile(
            "f_([0-9]+)_t_([0-9]+)-([a-z0-9_-]+)-([0-9])")
Example no. 26
    def __init__(
        self,
        search_space,  # layer2
        device,
        genotypes,  # layer2
        micro_model_type="micro-dense-model",
        micro_model_cfg={},
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        auxiliary_head=False,
        auxiliary_cfg=None,
        schedule_cfg=None,
    ):
        super(MacroStagewiseFinalModel, self).__init__(schedule_cfg)

        self.macro_ss = search_space.macro_search_space
        self.micro_ss = search_space.micro_search_space
        self.device = device
        assert isinstance(genotypes, str)
        self.genotypes_str = genotypes
        self.macro_g, self.micro_g = genotype_from_str(genotypes, search_space)

        # micro model (cell) class
        micro_model_cls = FinalModel.get_class_(micro_model_type)  # cell type

        self.num_classes = num_classes
        self.init_channels = init_channels
        self.stem_multiplier = stem_multiplier
        self.stem_stride = stem_stride
        self.stem_affine = stem_affine
        self.use_stem = use_stem

        # training
        self.dropout_rate = dropout_rate
        self.dropout_path_rate = dropout_path_rate

        self.auxiliary_head = auxiliary_head

        self.overall_adj = self.macro_ss.parse_overall_adj(self.macro_g)
        self.layer_widths = [float(w) for w in self.macro_g.width.split(",")]

        # sort channels out
        assert self.stem_multiplier == 1, "Cannot handle stem_multiplier != 1 now"
        self.input_channel_list = [self.init_channels]
        for i in range(1, self.macro_ss.num_layers):
            self.input_channel_list.append(
                self.input_channel_list[i - 1] *
                2 if self._is_reduce(i - 1) else self.input_channel_list[i -
                                                                         1])
        for i in range(self.macro_ss.num_layers):
            self.input_channel_list[i] = int(
                self.input_channel_list[i] * self.layer_widths[i]
                if not self._is_reduce(i) else self.input_channel_list[i] *
                self.layer_widths[i - 1])

        self.output_channel_list = self.input_channel_list[1:] + [
            self.input_channel_list[-1]
        ]

        # construct cells
        if not self.use_stem:
            raise NotImplementedError
            c_stem = 3
        elif isinstance(self.use_stem, (list, tuple)):
            raise NotImplementedError
            self.stems = []
            c_stem = self.stem_multiplier * self.init_channels
            for i, stem_type in enumerate(self.use_stem):
                c_in = 3 if i == 0 else c_stem
                self.stems.append(
                    ops.get_op(stem_type)(c_in,
                                          c_stem,
                                          stride=stem_stride,
                                          affine=stem_affine))
            self.stem = nn.Sequential(*self.stems)
        else:
            self.stem = ops.get_op(self.use_stem)(3,
                                                  self.input_channel_list[0],
                                                  stride=stem_stride,
                                                  affine=stem_affine)

        self.cells = nn.ModuleList()
        self.micro_arch_list = self.micro_ss.rollout_from_genotype(
            self.micro_g).arch
        for i_layer in range(self.macro_ss.num_layers):
            # print(i_layer, self._is_reduce(i_layer))
            stride = 2 if self._is_reduce(i_layer) else 1
            cg_idx = self.macro_ss.cell_layout[i_layer]
            # construct the micro cell
            # FIXME: currently MacroStagewiseFinalModel does not support postprocess=False
            micro_model_cfg["postprocess"] = True
            cell = micro_model_cls(
                self.micro_ss,
                self.micro_arch_list[cg_idx],
                num_input_channels=self.input_channel_list[i_layer],
                num_out_channels=self.output_channel_list[i_layer],
                stride=stride,
                **micro_model_cfg)
            # assume non-reduce cell does not change channel number
            self.cells.append(cell)
            # add auxiliary head
            if i_layer == (
                    2 * self.macro_ss.num_layers) // 3 and self.auxiliary_head:
                if auxiliary_head == "imagenet":
                    self.auxiliary_net = AuxiliaryHeadImageNet(
                        self.output_channel_list[i_layer], num_classes,
                        **(auxiliary_cfg or {}))
                else:
                    self.auxiliary_net = AuxiliaryHead(
                        self.output_channel_list[i_layer], num_classes,
                        **(auxiliary_cfg or {}))

        self.lastact = nn.Identity()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        if self.dropout_rate and self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=self.dropout_rate)
        else:
            self.dropout = ops.Identity()
        self.classifier = nn.Linear(self.output_channel_list[-1],
                                    self.num_classes)
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self._set_hook()
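
For reference, a toy walk-through of the channel-list arithmetic in this constructor, with a hypothetical 5-layer macro layout (only layer 2 reducing), `init_channels=36`, and made-up per-layer width multipliers; `_is_reduce` here just checks a hard-coded set:

init_channels = 36
num_layers = 5
reduce_layers = {2}                         # hypothetical: only layer 2 is a reduction layer
layer_widths = [1.0, 0.5, 0.5, 0.5, 1.0]    # hypothetical widths parsed from the genotype


def _is_reduce(i_layer):
    return i_layer in reduce_layers


# base channels: double after every reduction layer
input_channel_list = [init_channels]
for i in range(1, num_layers):
    prev = input_channel_list[i - 1]
    input_channel_list.append(prev * 2 if _is_reduce(i - 1) else prev)

# scale by the width multipliers; a reduction layer keeps the preceding layer's multiplier
for i in range(num_layers):
    width = layer_widths[i] if not _is_reduce(i) else layer_widths[i - 1]
    input_channel_list[i] = int(input_channel_list[i] * width)

output_channel_list = input_channel_list[1:] + [input_channel_list[-1]]
print(input_channel_list)    # [36, 18, 18, 36, 72]
print(output_channel_list)   # [18, 18, 36, 72, 72]
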