Exemple #1
0
 def __init__(self,
              C_in,
              C_out,
              kernel_size,
              stride,
              padding,
              norm_layer,
              affine=True,
              input_size=None):
     """Build a stack of three depthwise/pointwise convolution pairs.

     Args:
         C_in: channel count of the input; also the width of every
             intermediate layer.
         C_out: channel count produced by the last pointwise conv.
         kernel_size: kernel size of the depthwise convs.
         stride: stride of the first depthwise conv (later ones use 1).
         padding: padding of the depthwise convs.
         norm_layer: forwarded to ``get_norm`` for each pointwise conv.
         affine: forwarded to ``self.get_flop``.
         input_size: indexed at ``[0]`` and ``[1]`` for the FLOPs
             estimate, so the ``None`` default cannot actually be used.
     """
     super(SepConvHeavy, self).__init__()

     def make_depthwise(dw_stride):
         # One filter per channel (groups == channels); no norm, no ReLU.
         return Conv2d(C_in,
                       C_in,
                       kernel_size=kernel_size,
                       stride=dw_stride,
                       padding=padding,
                       groups=C_in,
                       bias=False)

     def make_pointwise(width, with_relu):
         # 1x1 channel mixing followed by a norm layer; the ReLU is only
         # attached when requested, mirroring the layer-by-layer layout.
         options = dict(kernel_size=1,
                        padding=0,
                        bias=False,
                        norm=get_norm(norm_layer, width))
         if with_relu:
             options['activation'] = nn.ReLU()
         return Conv2d(C_in, width, **options)

     # Layers are created strictly in forward order.
     layers = []
     layers.append(make_depthwise(stride))
     layers.append(make_pointwise(C_in, True))
     # second depthwise-separable pair
     layers.append(make_depthwise(1))
     layers.append(make_pointwise(C_in, True))
     # third pair: projects to C_out and has no trailing ReLU
     layers.append(make_depthwise(1))
     layers.append(make_pointwise(C_out, False))
     self.op = nn.Sequential(*layers)

     kernel_hw = [kernel_size, kernel_size]
     self.flops = self.get_flop(kernel_hw, stride, C_in, C_out, affine,
                                input_size[0], input_size[1])
     # Kaiming (fan-in) initialisation of the whole stack
     weight_init.kaiming_init_module(self.op, mode='fan_in')
Exemple #2
0
 def __init__(self,
              C_in,
              C_out,
              kernel_size,
              stride,
              padding,
              dilation,
              norm_layer,
              affine=True,
              input_size=None):
     """Build a dilated depthwise conv followed by a 1x1 projection.

     Args:
         C_in: input channel count (also the depthwise group count).
         C_out: output channel count of the 1x1 conv.
         kernel_size: kernel size of the depthwise conv.
         stride: stride of the depthwise conv.
         padding: padding of the depthwise conv.
         dilation: dilation of the depthwise conv.
         norm_layer: forwarded to ``get_norm`` for the 1x1 conv.
         affine: forwarded to ``self.get_flop``.
         input_size: indexed at ``[0]`` and ``[1]`` for the FLOPs
             estimate, so the ``None`` default cannot actually be used.
     """
     super(DilConv, self).__init__()
     dilated_depthwise = Conv2d(C_in,
                                C_in,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                groups=C_in,
                                bias=False)
     projection = Conv2d(C_in,
                         C_out,
                         kernel_size=1,
                         padding=0,
                         bias=False,
                         norm=get_norm(norm_layer, C_out))
     self.op = nn.Sequential(dilated_depthwise, projection)

     kernel_hw = [kernel_size, kernel_size]
     self.flops = self.get_flop(kernel_hw, stride, C_in, C_out, affine,
                                input_size[0], input_size[1])
     # Kaiming (fan-in) initialisation of both convs
     weight_init.kaiming_init_module(self.op, mode='fan_in')
Exemple #3
0
 def __init__(self,
              C_in,
              C_out,
              kernel_size,
              stride,
              padding,
              norm_layer,
              expansion=4,
              affine=True,
              input_size=None):
     """Build an expand -> depthwise -> project convolution block.

     Args:
         C_in: input channel count.
         C_out: output channel count of the final 1x1 conv.
         kernel_size: kernel size of the depthwise conv.
         stride: stride of the depthwise conv.
         padding: padding of the depthwise conv.
         norm_layer: forwarded to ``get_norm`` for every conv.
         expansion: hidden width is ``expansion * C_in``.
         affine: forwarded to ``self.get_flop``.
         input_size: indexed at ``[0]`` and ``[1]`` for the FLOPs
             estimate, so the ``None`` default cannot actually be used.
     """
     super(MBConv, self).__init__()
     hidden = expansion * C_in
     self.hidden_dim = hidden

     # 1x1 expansion with norm + ReLU
     expand = Conv2d(C_in,
                     hidden,
                     1,
                     1,
                     0,
                     bias=False,
                     norm=get_norm(norm_layer, hidden),
                     activation=nn.ReLU())
     # depthwise spatial conv with norm + ReLU
     depthwise = Conv2d(hidden,
                        hidden,
                        kernel_size,
                        stride,
                        padding,
                        groups=hidden,
                        bias=False,
                        norm=get_norm(norm_layer, hidden),
                        activation=nn.ReLU())
     # linear 1x1 projection: norm only, deliberately no ReLU
     project = Conv2d(hidden,
                      C_out,
                      1,
                      1,
                      0,
                      bias=False,
                      norm=get_norm(norm_layer, C_out))
     self.op = nn.Sequential(expand, depthwise, project)

     kernel_hw = [kernel_size, kernel_size]
     self.flops = self.get_flop(kernel_hw, stride, C_in, C_out, affine,
                                input_size[0], input_size[1])
     # Kaiming (fan-in) initialisation of the whole block
     weight_init.kaiming_init_module(self.op, mode='fan_in')
Exemple #4
0
 def __init__(self,
              C_in,
              C_out,
              kernel_size,
              stride,
              padding,
              norm_layer,
              expansion=4,
              affine=True,
              input_size=None):
     """Build a reduce -> spatial conv -> restore bottleneck.

     Args:
         C_in: input channel count.
         C_out: output channel count of the final 1x1 conv.
         kernel_size: kernel size of the middle conv.
         stride: stride of the middle conv.
         padding: padding of the middle conv.
         norm_layer: forwarded to ``get_norm`` for every conv.
         expansion: hidden width is ``C_in // expansion``.
         affine: forwarded to ``self.get_flop``.
         input_size: indexed at ``[0]`` and ``[1]`` for the FLOPs
             estimate, so the ``None`` default cannot actually be used.
     """
     super(Bottleneck, self).__init__()
     hidden = C_in // expansion
     self.hidden_dim = hidden

     # 1x1 channel reduction with norm + ReLU
     reduce = Conv2d(C_in,
                     hidden,
                     kernel_size=1,
                     padding=0,
                     bias=False,
                     norm=get_norm(norm_layer, hidden),
                     activation=nn.ReLU())
     # spatial conv at the reduced width with norm + ReLU
     spatial = Conv2d(hidden,
                      hidden,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      bias=False,
                      norm=get_norm(norm_layer, hidden),
                      activation=nn.ReLU())
     # 1x1 restore to C_out: norm only, no activation
     restore = Conv2d(hidden,
                      C_out,
                      kernel_size=1,
                      padding=0,
                      bias=False,
                      norm=get_norm(norm_layer, C_out))
     self.op = nn.Sequential(reduce, spatial, restore)

     kernel_hw = [kernel_size, kernel_size]
     self.flops = self.get_flop(kernel_hw, stride, C_in, C_out, affine,
                                input_size[0], input_size[1])
     # Kaiming (fan-in) initialisation of the whole block
     weight_init.kaiming_init_module(self.op, mode='fan_in')
Exemple #5
0
 def __init__(self, C_in, C_out, norm_layer, affine=True, input_size=None):
     """Set up a pass-through, adding a 1x1 conv only if channels differ.

     Args:
         C_in: input channel count.
         C_out: requested output channel count.
         norm_layer: forwarded to ``get_norm`` when a conv is needed.
         affine: forwarded to ``self.get_flop`` when a conv is needed.
         input_size: indexed at ``[0]`` and ``[1]`` for the FLOPs
             estimate; only read when ``C_in != C_out``.
     """
     super(Identity, self).__init__()
     self.change = C_in != C_out
     if not self.change:
         # Same width on both sides: no module is created and no cost.
         self.flops = 0.0
         return

     projection_norm = get_norm(norm_layer, C_out)
     self.op = Conv2d(C_in,
                      C_out,
                      kernel_size=1,
                      padding=0,
                      bias=False,
                      norm=projection_norm)
     self.flops = self.get_flop([1, 1], 1, C_in, C_out, affine,
                                input_size[0], input_size[1])
     # Kaiming (fan-in) initialisation of the projection
     weight_init.kaiming_init_module(self.op, mode='fan_in')
Exemple #6
0
 def __init__(self, in_channels=3, out_channels=64, norm="BN"):
     """
     Create the stem's single 7x7, stride-2 convolution (``conv1``).

     Args:
         in_channels (int): channel count of the input.
         out_channels (int): channel count produced by ``conv1``.
         norm (str or callable): a callable that takes the number of
             channels and return a `nn.Module`, or a pre-defined string
             (one of {"FrozenBN", "BN", "GN"}).
     """
     super().__init__()
     stem_norm = get_norm(norm, out_channels)
     stem_conv = Conv2d(
         in_channels,
         out_channels,
         kernel_size=7,
         stride=2,
         padding=3,
         bias=False,
         norm=stem_norm,
     )
     self.conv1 = stem_conv
     weight_init.c2_msra_fill(self.conv1)
Exemple #7
0
 def __init__(self,
              C_in,
              C_out,
              kernel_size,
              stride,
              padding,
              norm_layer,
              affine=True,
              input_size=None):
     """Build a single conv + norm layer and record its FLOPs estimate.

     Args:
         C_in: input channel count.
         C_out: output channel count.
         kernel_size: kernel size of the conv.
         stride: stride of the conv.
         padding: padding of the conv.
         norm_layer: forwarded to ``get_norm``.
         affine: forwarded to ``self.get_flop``.
         input_size: indexed at ``[0]`` and ``[1]`` for the FLOPs
             estimate, so the ``None`` default cannot actually be used.
     """
     super(BasicResBlock, self).__init__()
     out_norm = get_norm(norm_layer, C_out)
     self.op = Conv2d(C_in,
                      C_out,
                      kernel_size,
                      stride=stride,
                      padding=padding,
                      bias=False,
                      norm=out_norm)

     kernel_hw = [kernel_size, kernel_size]
     self.flops = self.get_flop(kernel_hw, stride, C_in, C_out, affine,
                                input_size[0], input_size[1])
     # Kaiming (fan-in) initialisation of the conv
     weight_init.kaiming_init_module(self.op, mode='fan_in')
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """Build per-feature 1x1 decoder convs and a 3x3 predictor.

        Args:
            cfg: config object; ``MODEL.SEM_SEG_HEAD.*`` and
                ``MODEL.CAL_FLOPS`` are read from it.
            input_shape: maps feature name to a spec exposing ``stride``,
                ``channels``, ``height`` and ``width``. It must contain
                every name in ``IN_FEATURES`` as well as ``'layer_0'``.
        """
        super().__init__()
        head_cfg = cfg.MODEL.SEM_SEG_HEAD
        # fmt: off
        self.in_features = head_cfg.IN_FEATURES
        feature_strides = {}  # noqa:F841  (collected but not read below)
        feature_channels = {}
        feature_resolution = {}
        for name, spec in input_shape.items():
            feature_strides[name] = spec.stride
            feature_channels[name] = spec.channels
            feature_resolution[name] = np.array([spec.height, spec.width])
        self.ignore_value = head_cfg.IGNORE_VALUE
        num_classes = head_cfg.NUM_CLASSES
        norm = head_cfg.NORM
        self.loss_weight = head_cfg.LOSS_WEIGHT
        self.cal_flops = cfg.MODEL.CAL_FLOPS
        self.real_flops = 0.0
        # fmt: on

        self.layer_decoder_list = nn.ModuleList()
        # the FLOPs counter treats the norm as affine only for 'Sync' norms
        affine = 'Sync' in norm

        # simple decoder: one 1x1 conv + norm + ReLU per input feature
        for feat_name in self.in_features:
            height, width = feature_resolution[feat_name]
            width_in = feature_channels[feat_name]
            # 'layer_0' keeps its width; every other feature is halved
            width_out = width_in if feat_name == 'layer_0' else width_in // 2
            decoder_conv = Conv2d(width_in,
                                  width_out,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0,
                                  bias=False,
                                  norm=get_norm(norm, width_out),
                                  activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                height, width, width_in, width_out, [1, 1],
                is_affine=affine)
            self.layer_decoder_list.append(decoder_conv)

        # Kaiming init runs only after every decoder conv has been built
        for decoder_conv in self.layer_decoder_list:
            weight_init.kaiming_init_module(decoder_conv, mode='fan_in')

        # output layer: 3x3 conv on 'layer_0' channels -> class logits
        predictor_in = feature_channels['layer_0']
        top_height, top_width = feature_resolution['layer_0']
        self.predictor = Conv2d(in_channels=predictor_in,
                                out_channels=num_classes,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        self.real_flops += cal_op_flops.count_Conv_flop(
            top_height, top_width, predictor_in, num_classes, [3, 3])
        # Kaiming (fan-in) initialisation of the predictor
        weight_init.kaiming_init_module(self.predictor, mode='fan_in')
Exemple #9
0
    def __init__(self,
                 bottom_up,
                 in_features,
                 out_channels,
                 norm="",
                 top_block=None,
                 fuse_type="sum"):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            norm (str): the normalization to use.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra FPN levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
        """
        super(FPN, self).__init__()
        assert isinstance(bottom_up, Backbone)
        # Reject a bad fuse_type before any conv layers are allocated.
        assert fuse_type in {"avg", "sum"}

        # Feature map strides and channels from the bottom up network (e.g. ResNet)
        in_strides = [bottom_up.out_feature_strides[f] for f in in_features]
        in_channels = [bottom_up.out_feature_channels[f] for f in in_features]

        _assert_strides_are_log2_contiguous(in_strides)
        lateral_convs = []
        output_convs = []

        # A conv bias is only used when no norm layer follows the conv.
        use_bias = norm == ""
        # Distinct loop names: the channel list is no longer shadowed by
        # its own element.
        for stride, lateral_in_channels in zip(in_strides, in_channels):
            lateral_norm = get_norm(norm, out_channels)
            output_norm = get_norm(norm, out_channels)

            lateral_conv = Conv2d(lateral_in_channels,
                                  out_channels,
                                  kernel_size=1,
                                  bias=use_bias,
                                  norm=lateral_norm)
            output_conv = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
            )
            weight_init.c2_xavier_fill(lateral_conv)
            weight_init.c2_xavier_fill(output_conv)
            stage = int(math.log2(stride))
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.top_block = top_block
        self.in_features = in_features
        self.bottom_up = bottom_up
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {
            "p{}".format(int(math.log2(s))): s
            for s in in_strides
        }
        # top block output feature maps.
        if self.top_block is not None:
            # Stage of the last input feature, computed explicitly rather
            # than read from the variable left over by the loop above.
            last_stage = int(math.log2(in_strides[-1]))
            for s in range(last_stage,
                           last_stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2**(s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {
            k: out_channels
            for k in self._out_features
        }
        self._size_divisibility = in_strides[-1]
        self._fuse_type = fuse_type
Exemple #10
0
    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 out_channels=64,
                 input_res=None,
                 sept_stem=True,
                 norm="BN",
                 affine=True):
        """
        Build basic STEM for Dynamic Network.

        Three stem layers are created (``stem_1``, ``stem_2``, ``stem_3``)
        and their estimated cost is accumulated in ``self.real_flops``.

        Args:
            in_channels (int): channel count of the input.
            mid_channels (int): channel count after ``stem_1``/``stem_2``.
            out_channels (int): channel count after ``stem_3``.
            input_res: input resolution. It is indexed at ``[0]``/``[1]``
                and floor-divided by 2 as a whole, so the ``None`` default
                cannot actually be used (presumably a 2-element numpy
                array; confirm against the caller).
            sept_stem (bool): if True, ``stem_2`` and ``stem_3`` are each a
                depthwise 3x3 conv followed by a 1x1 conv; otherwise each
                is a single 3x3 conv.
            norm (str or callable): a callable that takes the number of
                channels and return a `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
            affine (bool): forwarded as ``is_affine`` to the FLOPs counters.
        """
        super().__init__()

        self.real_flops = 0.0
        # start with 3 stem layers down-sampling by 4.
        self.stem_1 = Conv2d(in_channels,
                             mid_channels,
                             kernel_size=3,
                             stride=2,
                             bias=False,
                             norm=get_norm(norm, mid_channels),
                             activation=nn.ReLU())
        # Count with the real input width: a literal 3 here was only
        # right for the default in_channels.
        self.real_flops += cal_op_flops.count_ConvBNReLU_flop(input_res[0],
                                                              input_res[1],
                                                              in_channels,
                                                              mid_channels,
                                                              [3, 3],
                                                              stride=2,
                                                              is_affine=affine)
        # stem 2 works at half the input resolution
        input_res = input_res // 2
        if not sept_stem:
            self.stem_2 = Conv2d(mid_channels,
                                 mid_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 bias=False,
                                 norm=get_norm(norm, mid_channels),
                                 activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                input_res[0],
                input_res[1],
                mid_channels,
                mid_channels, [3, 3],
                is_affine=affine)
        else:
            # depthwise 3x3 followed by a 1x1 conv with norm + ReLU
            self.stem_2 = nn.Sequential(
                Conv2d(mid_channels,
                       mid_channels,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       groups=mid_channels,
                       bias=False),
                Conv2d(mid_channels,
                       mid_channels,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, mid_channels),
                       activation=nn.ReLU()))
            self.real_flops += (
                cal_op_flops.count_Conv_flop(input_res[0],
                                             input_res[1],
                                             mid_channels,
                                             mid_channels, [3, 3],
                                             groups=mid_channels) +
                cal_op_flops.count_ConvBNReLU_flop(input_res[0],
                                                   input_res[1],
                                                   mid_channels,
                                                   mid_channels, [1, 1],
                                                   is_affine=affine))
        # stem 3 down-samples by another factor of 2
        if not sept_stem:
            self.stem_3 = Conv2d(mid_channels,
                                 out_channels,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 bias=False,
                                 norm=get_norm(norm, out_channels),
                                 activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                input_res[0],
                input_res[1],
                mid_channels,
                out_channels, [3, 3],
                stride=2,
                is_affine=affine)
        else:
            self.stem_3 = nn.Sequential(
                Conv2d(mid_channels,
                       mid_channels,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       groups=mid_channels,
                       bias=False),
                Conv2d(mid_channels,
                       out_channels,
                       kernel_size=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, out_channels),
                       activation=nn.ReLU()))
            # the 1x1 conv is counted at the already-halved resolution
            self.real_flops += (
                cal_op_flops.count_Conv_flop(input_res[0],
                                             input_res[1],
                                             mid_channels,
                                             mid_channels, [3, 3],
                                             stride=2,
                                             groups=mid_channels) +
                cal_op_flops.count_ConvBNReLU_flop(input_res[0] // 2,
                                                   input_res[1] // 2,
                                                   mid_channels,
                                                   out_channels, [1, 1],
                                                   is_affine=affine))
        self.out_res = input_res // 2
        self.out_cha = out_channels
        # using Kaiming init
        for layer in [self.stem_1, self.stem_2, self.stem_3]:
            weight_init.kaiming_init_module(layer, mode='fan_in')
Exemple #11
0
    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Build the 1x1 -> 3x3 -> 1x1 convs and the optional shortcut conv.

        Args:
            in_channels (int): channel count of the block input.
            out_channels (int): channel count of the block output.
            bottleneck_channels (int): channel count of the 3x3 conv.
            stride (int): stride applied by the block.
            num_groups (int): ``groups`` of the 3x3 conv.
            norm (str or callable): a callable that takes the number of
                channels and return a `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
            stride_in_1x1 (bool): when stride==2, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): dilation (and padding) of the 3x3 conv.
        """
        super().__init__(in_channels, out_channels, stride)

        # A projection shortcut is only needed when the width changes.
        shortcut = None
        if in_channels != out_channels:
            shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        self.shortcut = shortcut

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        reduce_norm = get_norm(norm, bottleneck_channels)
        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=reduce_norm,
        )

        spatial_norm = get_norm(norm, bottleneck_channels)
        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=spatial_norm,
        )

        expand_norm = get_norm(norm, out_channels)
        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=expand_norm,
        )

        # MSRA fill, main path first and the shortcut (if any) last.
        for layer in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if layer is None:  # no projection shortcut was created
                continue
            weight_init.c2_msra_fill(layer)
Exemple #12
0
    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
        deform_modulated=False,
        deform_num_groups=1,
    ):
        """
        Similar to :class:`BottleneckBlock`, but with deformable conv in the 3x3 convolution.

        An extra regular conv (``conv2_offset``) is created alongside the
        deformable ``conv2``; its weight and bias start at zero.
        """
        super().__init__(in_channels, out_channels, stride)
        self.deform_modulated = deform_modulated

        # A projection shortcut is only needed when the width changes.
        shortcut = None
        if in_channels != out_channels:
            shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        self.shortcut = shortcut

        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
        deform_conv_op, offset_channels = (
            (ModulatedDeformConv, 27) if deform_modulated
            else (DeformConv, 18))

        # Geometry shared by the offset conv and the deformable conv.
        spatial_geometry = dict(
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            dilation=dilation,
        )
        self.conv2_offset = Conv2d(
            bottleneck_channels,
            offset_channels * deform_num_groups,
            **spatial_geometry,
        )
        self.conv2 = deform_conv_op(
            bottleneck_channels,
            bottleneck_channels,
            bias=False,
            groups=num_groups,
            deformable_groups=deform_num_groups,
            norm=get_norm(norm, bottleneck_channels),
            **spatial_geometry,
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, main path first and the shortcut (if any) last.
        for layer in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if layer is None:  # no projection shortcut was created
                continue
            weight_init.c2_msra_fill(layer)

        # The offset conv starts with all-zero weight and bias.
        for offset_param in (self.conv2_offset.weight,
                             self.conv2_offset.bias):
            nn.init.constant_(offset_param, 0)
    def __init__(self,
                 C_in,
                 C_out,
                 norm,
                 allow_up,
                 allow_down,
                 input_size,
                 cell_type,
                 cal_flops=True,
                 using_gate=False,
                 small_gate=False,
                 gate_bias=1.5,
                 affine=True):
        """Build one cell: mixed op, resolution branches and optional gate.

        Args:
            C_in: input channel count of the cell.
            C_out: output channel count of the mixed op.
            norm: normalization spec forwarded to ``Mixed_OP`` and
                ``get_norm``.
            allow_up: if truthy, create the ``res_up`` branch (1x1 conv to
                ``C_out // 2`` channels).
            allow_down: if truthy, create the ``res_down`` branch
                (stride-2 1x1 conv to ``2 * C_out`` channels).
            input_size: input resolution; indexed at ``[0]``/``[1]`` and,
                when ``small_gate`` is set, floor-divided by 4 as a whole
                (presumably a 2-element numpy array; confirm against the
                caller).
            cell_type: forwarded to ``Mixed_OP``.
            cal_flops: stored on the instance as ``self.cal_flops``.
            using_gate: if truthy, build the ``gate_conv_beta`` network;
                otherwise register a constant all-ones buffer
                ``gate_weights_beta``.
            small_gate: if truthy, the gate FLOPs are counted at a
                quarter of the input resolution.
            gate_bias: ``bias`` value passed to the gate's Kaiming init.
            affine: forwarded to ``Mixed_OP`` and as ``is_affine`` to the
                FLOPs counters.
        """
        super(Cell, self).__init__()
        self.channel_in = C_in
        self.channel_out = C_out
        self.allow_up = allow_up
        self.allow_down = allow_down
        self.cal_flops = cal_flops
        self.using_gate = using_gate
        self.small_gate = small_gate

        self.cell_ops = Mixed_OP(inplanes=self.channel_in,
                                 outplanes=self.channel_out,
                                 stride=1,
                                 cell_type=cell_type,
                                 norm=norm,
                                 affine=affine,
                                 input_size=input_size)
        self.cell_flops = self.cell_ops.flops
        # resolution keep
        self.res_keep = nn.ReLU()
        self.res_keep_flops = cal_op_flops.count_ReLU_flop(
            input_size[0], input_size[1], self.channel_out)
        # resolution up and dim down
        if self.allow_up:
            self.res_up = nn.Sequential(
                nn.ReLU(),
                Conv2d(self.channel_out,
                       self.channel_out // 2,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, self.channel_out // 2),
                       activation=nn.ReLU()))
            # calculate Flops
            self.res_up_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1],
                self.channel_out) + cal_op_flops.count_ConvBNReLU_flop(
                    input_size[0],
                    input_size[1],
                    self.channel_out,
                    self.channel_out // 2, [1, 1],
                    is_affine=affine)
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_up, mode='fan_in')
        # resolution down and dim up
        if self.allow_down:
            self.res_down = nn.Sequential(
                nn.ReLU(),
                Conv2d(self.channel_out,
                       2 * self.channel_out,
                       kernel_size=1,
                       stride=2,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, 2 * self.channel_out),
                       activation=nn.ReLU()))
            # calculate Flops
            self.res_down_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1],
                self.channel_out) + cal_op_flops.count_ConvBNReLU_flop(
                    input_size[0],
                    input_size[1],
                    self.channel_out,
                    2 * self.channel_out, [1, 1],
                    stride=2,
                    is_affine=affine)
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_down, mode='fan_in')
        # One gate for the keep path plus one per enabled up/down branch.
        self.gate_num = 1 + bool(self.allow_up) + bool(self.allow_down)
        if self.using_gate:
            self.gate_conv_beta = nn.Sequential(
                Conv2d(self.channel_in,
                       self.channel_in // 2,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=False,
                       norm=get_norm(norm, self.channel_in // 2),
                       activation=nn.ReLU()), nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(self.channel_in // 2,
                       self.gate_num,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       bias=True))
            # With small_gate the gate cost is counted at 1/4 resolution.
            if self.small_gate:
                input_size = input_size // 4
            self.gate_flops = cal_op_flops.count_ConvBNReLU_flop(
                input_size[0],
                input_size[1],
                self.channel_in,
                self.channel_in // 2, [1, 1],
                is_affine=affine) + cal_op_flops.count_Pool2d_flop(
                    input_size[0], input_size[1], self.channel_in // 2,
                    [1, 1], 1) + cal_op_flops.count_Conv_flop(
                        1, 1, self.channel_in // 2, self.gate_num, [1, 1])
            # using Kaiming init and predefined bias for gate
            weight_init.kaiming_init_module(self.gate_conv_beta,
                                            mode='fan_in',
                                            bias=gate_bias)
        else:
            # No explicit .cuda(): a registered buffer is moved together
            # with the module by .to()/.cuda(), and creating it on the
            # default device keeps construction working on CPU-only hosts.
            self.register_buffer('gate_weights_beta',
                                 torch.ones(1, self.gate_num, 1, 1))
            self.gate_flops = 0.0